Reproducing Kingma & Welling (2013), Fig. 2 (MNIST VAE)
This example reproduces Figure 2 from the paper on Auto-Encoding Variational Bayes (Kingma & Welling (2013)). The experiment trains Variational Autoencoders (VAEs) with different latent dimensionalities \(N_z\) on MNIST and tracks the Evidence Lower Bound (ELBO) as a function of the number of training samples processed.
The setup follows the configuration in the paper:
one-hidden-layer MLP encoder and decoder with Tanh activations;
hidden size \(H = 500\);
mini-batch size \(M = 100\);
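learning rate selected from \(\{0.01, 0.02, 0.1\}\) (here: 0.02), with Adagrad step-size adaptation;
small weight decay (a weight-decay term corresponding to an \(\mathcal{N}(0, I)\) prior on the weights) — note that the script below constructs its Adagrad optimizer without weight decay by default;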
\(L = 1\) Monte Carlo sample for the latent variable estimator.
Configuration
We begin with imports, global settings, and the reproducibility utilities. This includes the list of latent dimensions used for \(N_z\), hyperparameters, and device configuration.
1"""
2Reproduce the MNIST VAE experiment from Kingma & Welling (2013), Fig. 2.
3
4We train VAEs with different latent dimensionalities N_z and track the ELBO
5as a function of the number of training samples seen.
6The setup follows the original paper:
7
8- One hidden layer MLPs with Tanh activations in encoder/decoder
9- Hidden size H = 500
10- Mini-batch size M = 100
11- Learning rate selected from {0.01, 0.02, 0.1} (we use 0.02 here)
12- L = 1 Monte Carlo sample for the stochastic latent variable
13"""
14
15from __future__ import annotations
16
17import random
18import time
19from pathlib import Path
20from typing import Dict, List, Tuple
21
22import matplotlib.pyplot as plt
23import numpy as np
24import torch
25import torch.nn as nn
26from matplotlib.ticker import (
27 LogFormatterMathtext,
28 LogLocator,
29 NullFormatter,
30)
31from torch.utils.data import DataLoader
32from torchvision import datasets, transforms
33
34from pyautoencoder.variational import VAE
35
36
# ---------------- Configuration ---------------- #
# Experimental settings mirroring Kingma & Welling (2013), Fig. 2.
LATENT_DIMS: List[int] = [3, 5, 10, 20, 200]  # N_z values (paper)
HIDDEN_SIZE: int = 500  # hidden layer size (encoder/decoder, MNIST)
BATCH_SIZE: int = 100  # M = 100 (paper)
LEARNING_RATE: float = 0.02  # chosen from {0.01, 0.02, 0.1} (paper)
MC_SAMPLES: int = 1  # L = 1 (paper)
TARGET_TRAIN_SAMPLES: int = 10_000_000  # stop after this many training samples (1e7)
EVAL_EVERY_SAMPLES: int = 100_000  # evaluate and log every this many samples (1e5)
USE_STOCHASTIC_BINARIZATION: bool = False  # if True, binarize x ~ Bernoulli(p = x) on load

# Reproducibility and hardware settings.
SEED: int = 1926
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
49
50
51# ---------------- Utilities ------------------- #
def set_seed(seed: int) -> None:
    """Seed every random number generator the experiment relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator; when CUDA is available, the current CUDA device's generator is
    seeded as well. Deterministic-algorithm flags are not touched.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
58
59
def init_weights_small_normal(module: nn.Module) -> None:
    """Re-initialise an ``nn.Linear`` layer with small Gaussian weights.

    Meant to be passed to ``model.apply(...)``. For every ``nn.Linear`` the
    weights are redrawn from a normal distribution with mean 0 and std 0.01,
    and the bias (if the layer has one) is set to zero. Any other module type
    is left unchanged.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.normal_(module.weight, mean=0.0, std=0.01)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
Key parameters:
LATENT_DIMS – list of latent sizes \(N_z\) evaluated.
HIDDEN_SIZE – hidden layer size in encoder and decoder.
MC_SAMPLES – number of Monte Carlo samples (\(L=1\) in the paper).
set_seed() – seeds the Python, NumPy and PyTorch random number generators for reproducibility.
Data Loading
The MNIST dataset is loaded through torchvision.datasets.
Optionally, inputs may be stochastically binarized as in the original
AEVB experiments (set USE_STOCHASTIC_BINARIZATION = True; disabled by default).
class _BernoulliBinarize:
    """Transform that stochastically binarizes an image tensor.

    Samples ``x_bin ~ Bernoulli(p = x)`` element-wise; assumes ``x`` has values
    in [0, 1] (as produced by ``transforms.ToTensor()``). Defined at module
    level, not inside ``make_dataloaders``, so it can be pickled when the
    DataLoader uses worker processes.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.bernoulli(x)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


def make_dataloaders(
    batch_size: int,
    use_stochastic_binarization: bool = USE_STOCHASTIC_BINARIZATION,
    *,
    root: str = "./data",
    num_workers: int = 0,
) -> Tuple[DataLoader, DataLoader]:
    """Build the MNIST train and test dataloaders.

    Args:
        batch_size: Mini-batch size used by both loaders.
        use_stochastic_binarization: If True, each image is re-binarized with
            ``_BernoulliBinarize`` every time it is loaded.
        root: Directory where torchvision stores (and, if needed, downloads)
            MNIST.
        num_workers: Number of DataLoader worker processes; 0 loads data in
            the main process.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and drops
        the last incomplete batch; the test loader keeps the original order
        and every sample.
    """
    steps: list = [transforms.ToTensor()]
    if use_stochastic_binarization:
        steps.append(_BernoulliBinarize())
    transform = transforms.Compose(steps)

    train_dataset = datasets.MNIST(root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root, train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
The make_dataloaders() function returns training and test dataloaders
with appropriate preprocessing and optional binarization.
Model Definition
Each VAE consists of:
a single-hidden-layer encoder mapping \(x \mapsto (\mu, \log \sigma^2)\),
a single-hidden-layer decoder mapping \(z \mapsto \hat{x}\),
Tanh nonlinearities,
linear output logits interpreted by
VAE.compute_loss() under a Bernoulli likelihood.
Weights are initialized with a small Gaussian (standard deviation 0.01, zero biases) as described in the paper.
def make_vae(latent_dim: int, hidden: int = HIDDEN_SIZE) -> VAE:
    """Create an MNIST VAE with a single Tanh hidden layer in encoder and decoder.

    Args:
        latent_dim: Latent dimensionality N_z.
        hidden: Width of the hidden layer in both encoder and decoder.

    Returns:
        A built VAE whose ``nn.Linear`` layers were re-initialised with
        ``init_weights_small_normal``. It is created on the default device;
        callers move it (e.g. ``.to(DEVICE)``).
    """
    n_pixels = 28 * 28
    image_shape = (1, 28, 28)

    encoder_layers = [nn.Flatten(), nn.Linear(n_pixels, hidden), nn.Tanh()]
    decoder_layers = [
        nn.Linear(latent_dim, hidden),
        nn.Tanh(),
        nn.Linear(hidden, n_pixels),
        # Output layer stays linear: the Bernoulli likelihood expects logits.
        nn.Unflatten(-1, image_shape),
    ]

    vae = VAE(
        encoder=nn.Sequential(*encoder_layers),
        decoder=nn.Sequential(*decoder_layers),
        latent_dim=latent_dim,
    )
    # build() infers the dimension-dependent parts from a representative input.
    vae.build(input_sample=torch.randn(1, *image_shape))
    vae.apply(init_weights_small_normal)
    return vae
The model is built explicitly via VAE.build(), which infers
dimension-dependent components from a representative input sample.
ELBO Evaluation
We evaluate the ELBO over a dataloader by averaging the batch-wise ELBO
diagnostics returned by VAE.compute_loss(), weighting each batch by its size.
@torch.no_grad()
def average_elbo(dataloader: DataLoader, model: VAE) -> float:
    """Return the per-sample average ELBO of ``model`` over ``dataloader``.

    Each batch's ``diagnostics['elbo']`` is assumed to be a batch mean and is
    weighted by the batch size, so a smaller final batch is handled correctly.
    The model is switched to eval mode and left there; the training loop
    re-enables training mode itself.

    Args:
        dataloader: Yields ``(x, label)`` pairs; labels are ignored.
        model: The VAE to evaluate.

    Returns:
        The average ELBO as a plain Python float.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    elbo_sum = 0.0
    count = 0

    for x, _ in dataloader:
        x = x.to(DEVICE)
        out = model(x, S=MC_SAMPLES)
        loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
        batch_size = x.size(0)
        # float() accepts both Python numbers and 0-dim tensors, so the running
        # total (and the return value) is always a Python float as annotated.
        elbo_sum += float(loss_info.diagnostics['elbo']) * batch_size
        count += batch_size

    if count == 0:
        raise ValueError("average_elbo: dataloader yielded no samples")
    return elbo_sum / count
This routine is used during training to record both training and test ELBO values.
Training Loop
For each latent dimension \(N_z\), the VAE is trained until a fixed
number of training samples have been processed
(TARGET_TRAIN_SAMPLES). Periodic evaluations of the ELBO on both
train and test sets are logged.
def train_one_setting(
    latent_dim: int,
    train_loader: DataLoader,
    test_loader: DataLoader,
    *,
    weight_decay: float = 0.0,
) -> Tuple[VAE, List[Dict[str, float]]]:
    """Train one VAE with latent size ``latent_dim`` and record its ELBO curve.

    Training cycles over ``train_loader`` until TARGET_TRAIN_SAMPLES samples
    have been processed. Each time another EVAL_EVERY_SAMPLES samples have
    been seen, the average ELBO on the full train and test loaders is logged
    and printed.

    Args:
        latent_dim: Latent dimensionality N_z of the VAE.
        train_loader: Loader used for optimisation and for the train ELBO.
        test_loader: Loader used for the test ELBO.
        weight_decay: L2 penalty passed to Adagrad. The default of 0.0 applies
            no weight decay; a small positive value approximates the
            N(0, I) weight prior described in the paper.

    Returns:
        ``(model, logs)``, where each log entry has the keys ``"samples"``,
        ``"train_elbo"`` and ``"test_elbo"``.

    Raises:
        ValueError: if a full pass over ``train_loader`` yields no batches,
            which would otherwise make the training loop spin forever.
    """
    model = make_vae(latent_dim).to(DEVICE)
    optimizer = torch.optim.Adagrad(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=weight_decay,
    )

    logs: List[Dict[str, float]] = []
    samples_seen = 0
    next_eval = EVAL_EVERY_SAMPLES

    start_time = time.time()
    while samples_seen < TARGET_TRAIN_SAMPLES:
        samples_at_pass_start = samples_seen
        for x, _ in train_loader:
            x = x.to(DEVICE)
            # average_elbo() switches the model to eval mode, so re-enable
            # training mode before every step.
            model.train()
            optimizer.zero_grad()

            out = model(x, S=MC_SAMPLES)
            loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
            loss_info.objective.backward()
            optimizer.step()

            samples_seen += x.size(0)

            if samples_seen >= next_eval:
                train_elbo = average_elbo(train_loader, model)
                test_elbo = average_elbo(test_loader, model)
                logs.append(
                    {
                        "samples": float(samples_seen),
                        "train_elbo": float(train_elbo),
                        "test_elbo": float(test_elbo),
                    }
                )
                elapsed = time.time() - start_time
                print(
                    f"N_z={latent_dim:3d} | "
                    f"samples={samples_seen:>9d} | "
                    f"ELBO_train={train_elbo:.2f}, "
                    f"ELBO_test={test_elbo:.2f} | "
                    f"(elapsed {elapsed:.1f}s)"
                )
                next_eval += EVAL_EVERY_SAMPLES

            if samples_seen >= TARGET_TRAIN_SAMPLES:
                break

        if samples_seen == samples_at_pass_start:
            # An empty pass (e.g. dataset smaller than the batch size with
            # drop_last=True) would never advance samples_seen.
            raise ValueError(
                "train_one_setting: train_loader yielded no samples in a full pass"
            )

    return model, logs
Each training step:
Draws \(S\) Monte Carlo samples from \(q(z \mid x)\),
Computes the negative ELBO via
VAE.compute_loss() with likelihood='bernoulli',
Backpropagates gradients through encoder and decoder and applies an Adagrad update,
Records evaluation metrics periodically.
The ELBO diagnostics contain:
elbo – Evidence Lower Bound value,
log_likelihood – batch-mean reconstruction term,
kl_divergence – batch-mean KL divergence.
Plotting the Results
After all training runs complete, the ELBO curves are plotted as a function of the number of processed samples (log-scaled), following the appearance of Fig. 2 from Kingma & Welling (2013).
def plot_elbo_curves(
    all_logs: Dict[int, List[Dict[str, float]]],
    out_path: Path = Path("vae_mnist_fig2_repro.png"),
) -> Path:
    """Plot train/test ELBO against training samples, one panel per N_z.

    Args:
        all_logs: Mapping from latent size N_z to the log entries produced by
            ``train_one_setting`` (keys ``"samples"``, ``"train_elbo"``,
            ``"test_elbo"``). Panels are ordered by increasing N_z.
        out_path: Destination image file (a ``str`` is accepted too). Defaults
            to ``vae_mnist_fig2_repro.png`` in the working directory.

    Returns:
        The path the figure was written to.

    Raises:
        ValueError: if ``all_logs`` is empty (there would be no panels).
    """
    if not all_logs:
        raise ValueError("plot_elbo_curves: all_logs is empty, nothing to plot")

    latent_dims_sorted = sorted(all_logs)
    num_settings = len(latent_dims_sorted)
    # squeeze=False always returns a 2-D array of Axes, so a single panel
    # needs no special casing.
    fig, axes_grid = plt.subplots(
        1, num_settings, figsize=(3.0 * num_settings, 3.0), squeeze=False
    )
    axes = axes_grid[0]

    for i, (nz, ax) in enumerate(zip(latent_dims_sorted, axes)):
        entries = all_logs[nz]
        xs = [entry["samples"] for entry in entries]
        ys_train = [entry["train_elbo"] for entry in entries]
        ys_test = [entry["test_elbo"] for entry in entries]

        ax.plot(xs, ys_train, label="AEVB (train)", color="r")
        ax.plot(xs, ys_test, linestyle="--", label="AEVB (test)", color="r")

        ax.set_xscale("log")
        ax.set_ylim(-150, -95)
        ax.set_title(f"MNIST, $N_z = {nz}$")

        # Axis labels only on the leftmost panel.
        if i == 0:
            ax.set_xlabel("# training samples evaluated")
            ax.set_ylabel(r"$\mathcal{L}$ (ELBO)")

        # Log-scale x ticks: labelled decades, unlabelled minor ticks in between.
        ax.xaxis.set_major_locator(LogLocator(base=10))
        ax.xaxis.set_major_formatter(LogFormatterMathtext(base=10))
        ax.xaxis.set_minor_locator(LogLocator(base=10, subs=range(2, 10)))
        ax.xaxis.set_minor_formatter(NullFormatter())

    axes[0].legend(loc="lower right")
    fig.tight_layout()

    out_path = Path(out_path)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure -> {out_path}")
    return out_path
The figure is saved as vae_mnist_fig2_repro.png.
Entry Point
The main() function orchestrates the full experiment: seeding,
data loading, training VAEs for each \(N_z\), logging results, and
plotting the final curves.
def main() -> None:
    """Run the full Fig. 2 reproduction.

    Seeds the RNGs, loads MNIST, trains one VAE per entry of LATENT_DIMS and
    writes the ELBO figure via ``plot_elbo_curves``.
    """
    print("=== Reproducing Kingma & Welling (2013) Fig. 2 on MNIST ===")
    print(f"Using device: {DEVICE}")

    set_seed(SEED)
    loaders = make_dataloaders(BATCH_SIZE)
    train_loader, test_loader = loaders

    all_logs: Dict[int, List[Dict[str, float]]] = {}
    for latent_dim in LATENT_DIMS:
        print(f"\n--- Training VAE with N_z = {latent_dim} ---")
        _model, run_logs = train_one_setting(latent_dim, train_loader, test_loader)
        all_logs[latent_dim] = run_logs

    plot_elbo_curves(all_logs)
    print("Done.")
The script can be executed directly:
if __name__ == "__main__":
    # Executed as a script (not imported): run the whole experiment.
    main()
Full Example Script
For convenience, here is the complete script in one block:
1"""
2Reproduce the MNIST VAE experiment from Kingma & Welling (2013), Fig. 2.
3
4We train VAEs with different latent dimensionalities N_z and track the ELBO
5as a function of the number of training samples seen.
6The setup follows the original paper:
7
8- One hidden layer MLPs with Tanh activations in encoder/decoder
9- Hidden size H = 500
10- Mini-batch size M = 100
11- Learning rate selected from {0.01, 0.02, 0.1} (we use 0.02 here)
12- L = 1 Monte Carlo sample for the stochastic latent variable
13"""
14
15from __future__ import annotations
16
17import random
18import time
19from pathlib import Path
20from typing import Dict, List, Tuple
21
22import matplotlib.pyplot as plt
23import numpy as np
24import torch
25import torch.nn as nn
26from matplotlib.ticker import (
27 LogFormatterMathtext,
28 LogLocator,
29 NullFormatter,
30)
31from torch.utils.data import DataLoader
32from torchvision import datasets, transforms
33
34from pyautoencoder.variational import VAE
35
36
# ---------------- Configuration ---------------- #
# Experimental settings mirroring Kingma & Welling (2013), Fig. 2.
LATENT_DIMS: List[int] = [3, 5, 10, 20, 200]  # N_z values (paper)
HIDDEN_SIZE: int = 500  # hidden layer size (encoder/decoder, MNIST)
BATCH_SIZE: int = 100  # M = 100 (paper)
LEARNING_RATE: float = 0.02  # chosen from {0.01, 0.02, 0.1} (paper)
MC_SAMPLES: int = 1  # L = 1 (paper)
TARGET_TRAIN_SAMPLES: int = 10_000_000  # stop after this many training samples (1e7)
EVAL_EVERY_SAMPLES: int = 100_000  # evaluate and log every this many samples (1e5)
USE_STOCHASTIC_BINARIZATION: bool = False  # if True, binarize x ~ Bernoulli(p = x) on load

# Reproducibility and hardware settings.
SEED: int = 1926
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
49
50
51# ---------------- Utilities ------------------- #
def set_seed(seed: int) -> None:
    """Seed every random number generator the experiment relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator; when CUDA is available, the current CUDA device's generator is
    seeded as well. Deterministic-algorithm flags are not touched.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
58
59
def init_weights_small_normal(module: nn.Module) -> None:
    """Re-initialise an ``nn.Linear`` layer with small Gaussian weights.

    Meant to be passed to ``model.apply(...)``. For every ``nn.Linear`` the
    weights are redrawn from a normal distribution with mean 0 and std 0.01,
    and the bias (if the layer has one) is set to zero. Any other module type
    is left unchanged.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.normal_(module.weight, mean=0.0, std=0.01)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
65
66
67# ---------------- Data ----------------------- #
class _BernoulliBinarize:
    """Transform that stochastically binarizes an image tensor.

    Samples ``x_bin ~ Bernoulli(p = x)`` element-wise; assumes ``x`` has values
    in [0, 1] (as produced by ``transforms.ToTensor()``). Defined at module
    level, not inside ``make_dataloaders``, so it can be pickled when the
    DataLoader uses worker processes.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.bernoulli(x)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


def make_dataloaders(
    batch_size: int,
    use_stochastic_binarization: bool = USE_STOCHASTIC_BINARIZATION,
    *,
    root: str = "./data",
    num_workers: int = 0,
) -> Tuple[DataLoader, DataLoader]:
    """Build the MNIST train and test dataloaders.

    Args:
        batch_size: Mini-batch size used by both loaders.
        use_stochastic_binarization: If True, each image is re-binarized with
            ``_BernoulliBinarize`` every time it is loaded.
        root: Directory where torchvision stores (and, if needed, downloads)
            MNIST.
        num_workers: Number of DataLoader worker processes; 0 loads data in
            the main process.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and drops
        the last incomplete batch; the test loader keeps the original order
        and every sample.
    """
    steps: list = [transforms.ToTensor()]
    if use_stochastic_binarization:
        steps.append(_BernoulliBinarize())
    transform = transforms.Compose(steps)

    train_dataset = datasets.MNIST(root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root, train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
91
92
93# ---------------- Model ---------------------- #
def make_vae(latent_dim: int, hidden: int = HIDDEN_SIZE) -> VAE:
    """Create an MNIST VAE with a single Tanh hidden layer in encoder and decoder.

    Args:
        latent_dim: Latent dimensionality N_z.
        hidden: Width of the hidden layer in both encoder and decoder.

    Returns:
        A built VAE whose ``nn.Linear`` layers were re-initialised with
        ``init_weights_small_normal``. It is created on the default device;
        callers move it (e.g. ``.to(DEVICE)``).
    """
    n_pixels = 28 * 28
    image_shape = (1, 28, 28)

    encoder_layers = [nn.Flatten(), nn.Linear(n_pixels, hidden), nn.Tanh()]
    decoder_layers = [
        nn.Linear(latent_dim, hidden),
        nn.Tanh(),
        nn.Linear(hidden, n_pixels),
        # Output layer stays linear: the Bernoulli likelihood expects logits.
        nn.Unflatten(-1, image_shape),
    ]

    vae = VAE(
        encoder=nn.Sequential(*encoder_layers),
        decoder=nn.Sequential(*decoder_layers),
        latent_dim=latent_dim,
    )
    # build() infers the dimension-dependent parts from a representative input.
    vae.build(input_sample=torch.randn(1, *image_shape))
    vae.apply(init_weights_small_normal)
    return vae
112
113# ---------------- Evaluation ----------------- #
@torch.no_grad()
def average_elbo(dataloader: DataLoader, model: VAE) -> float:
    """Return the per-sample average ELBO of ``model`` over ``dataloader``.

    Each batch's ``diagnostics['elbo']`` is assumed to be a batch mean and is
    weighted by the batch size, so a smaller final batch is handled correctly.
    The model is switched to eval mode and left there; the training loop
    re-enables training mode itself.

    Args:
        dataloader: Yields ``(x, label)`` pairs; labels are ignored.
        model: The VAE to evaluate.

    Returns:
        The average ELBO as a plain Python float.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    elbo_sum = 0.0
    count = 0

    for x, _ in dataloader:
        x = x.to(DEVICE)
        out = model(x, S=MC_SAMPLES)
        loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
        batch_size = x.size(0)
        # float() accepts both Python numbers and 0-dim tensors, so the running
        # total (and the return value) is always a Python float as annotated.
        elbo_sum += float(loss_info.diagnostics['elbo']) * batch_size
        count += batch_size

    if count == 0:
        raise ValueError("average_elbo: dataloader yielded no samples")
    return elbo_sum / count
130
131
132# ---------------- Training ------------------- #
def train_one_setting(
    latent_dim: int,
    train_loader: DataLoader,
    test_loader: DataLoader,
    *,
    weight_decay: float = 0.0,
) -> Tuple[VAE, List[Dict[str, float]]]:
    """Train one VAE with latent size ``latent_dim`` and record its ELBO curve.

    Training cycles over ``train_loader`` until TARGET_TRAIN_SAMPLES samples
    have been processed. Each time another EVAL_EVERY_SAMPLES samples have
    been seen, the average ELBO on the full train and test loaders is logged
    and printed.

    Args:
        latent_dim: Latent dimensionality N_z of the VAE.
        train_loader: Loader used for optimisation and for the train ELBO.
        test_loader: Loader used for the test ELBO.
        weight_decay: L2 penalty passed to Adagrad. The default of 0.0 applies
            no weight decay; a small positive value approximates the
            N(0, I) weight prior described in the paper.

    Returns:
        ``(model, logs)``, where each log entry has the keys ``"samples"``,
        ``"train_elbo"`` and ``"test_elbo"``.

    Raises:
        ValueError: if a full pass over ``train_loader`` yields no batches,
            which would otherwise make the training loop spin forever.
    """
    model = make_vae(latent_dim).to(DEVICE)
    optimizer = torch.optim.Adagrad(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=weight_decay,
    )

    logs: List[Dict[str, float]] = []
    samples_seen = 0
    next_eval = EVAL_EVERY_SAMPLES

    start_time = time.time()
    while samples_seen < TARGET_TRAIN_SAMPLES:
        samples_at_pass_start = samples_seen
        for x, _ in train_loader:
            x = x.to(DEVICE)
            # average_elbo() switches the model to eval mode, so re-enable
            # training mode before every step.
            model.train()
            optimizer.zero_grad()

            out = model(x, S=MC_SAMPLES)
            loss_info = model.compute_loss(x, out, beta=1, likelihood='bernoulli')
            loss_info.objective.backward()
            optimizer.step()

            samples_seen += x.size(0)

            if samples_seen >= next_eval:
                train_elbo = average_elbo(train_loader, model)
                test_elbo = average_elbo(test_loader, model)
                logs.append(
                    {
                        "samples": float(samples_seen),
                        "train_elbo": float(train_elbo),
                        "test_elbo": float(test_elbo),
                    }
                )
                elapsed = time.time() - start_time
                print(
                    f"N_z={latent_dim:3d} | "
                    f"samples={samples_seen:>9d} | "
                    f"ELBO_train={train_elbo:.2f}, "
                    f"ELBO_test={test_elbo:.2f} | "
                    f"(elapsed {elapsed:.1f}s)"
                )
                next_eval += EVAL_EVERY_SAMPLES

            if samples_seen >= TARGET_TRAIN_SAMPLES:
                break

        if samples_seen == samples_at_pass_start:
            # An empty pass (e.g. dataset smaller than the batch size with
            # drop_last=True) would never advance samples_seen.
            raise ValueError(
                "train_one_setting: train_loader yielded no samples in a full pass"
            )

    return model, logs
187
188
189# ---------------- Plotting ------------------- #
def plot_elbo_curves(
    all_logs: Dict[int, List[Dict[str, float]]],
    out_path: Path = Path("vae_mnist_fig2_repro.png"),
) -> Path:
    """Plot train/test ELBO against training samples, one panel per N_z.

    Args:
        all_logs: Mapping from latent size N_z to the log entries produced by
            ``train_one_setting`` (keys ``"samples"``, ``"train_elbo"``,
            ``"test_elbo"``). Panels are ordered by increasing N_z.
        out_path: Destination image file (a ``str`` is accepted too). Defaults
            to ``vae_mnist_fig2_repro.png`` in the working directory.

    Returns:
        The path the figure was written to.

    Raises:
        ValueError: if ``all_logs`` is empty (there would be no panels).
    """
    if not all_logs:
        raise ValueError("plot_elbo_curves: all_logs is empty, nothing to plot")

    latent_dims_sorted = sorted(all_logs)
    num_settings = len(latent_dims_sorted)
    # squeeze=False always returns a 2-D array of Axes, so a single panel
    # needs no special casing.
    fig, axes_grid = plt.subplots(
        1, num_settings, figsize=(3.0 * num_settings, 3.0), squeeze=False
    )
    axes = axes_grid[0]

    for i, (nz, ax) in enumerate(zip(latent_dims_sorted, axes)):
        entries = all_logs[nz]
        xs = [entry["samples"] for entry in entries]
        ys_train = [entry["train_elbo"] for entry in entries]
        ys_test = [entry["test_elbo"] for entry in entries]

        ax.plot(xs, ys_train, label="AEVB (train)", color="r")
        ax.plot(xs, ys_test, linestyle="--", label="AEVB (test)", color="r")

        ax.set_xscale("log")
        ax.set_ylim(-150, -95)
        ax.set_title(f"MNIST, $N_z = {nz}$")

        # Axis labels only on the leftmost panel.
        if i == 0:
            ax.set_xlabel("# training samples evaluated")
            ax.set_ylabel(r"$\mathcal{L}$ (ELBO)")

        # Log-scale x ticks: labelled decades, unlabelled minor ticks in between.
        ax.xaxis.set_major_locator(LogLocator(base=10))
        ax.xaxis.set_major_formatter(LogFormatterMathtext(base=10))
        ax.xaxis.set_minor_locator(LogLocator(base=10, subs=range(2, 10)))
        ax.xaxis.set_minor_formatter(NullFormatter())

    axes[0].legend(loc="lower right")
    fig.tight_layout()

    out_path = Path(out_path)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure -> {out_path}")
    return out_path
231
232
233# ---------------- Main ---------------------- #
def main() -> None:
    """Run the full Fig. 2 reproduction.

    Seeds the RNGs, loads MNIST, trains one VAE per entry of LATENT_DIMS and
    writes the ELBO figure via ``plot_elbo_curves``.
    """
    print("=== Reproducing Kingma & Welling (2013) Fig. 2 on MNIST ===")
    print(f"Using device: {DEVICE}")

    set_seed(SEED)
    loaders = make_dataloaders(BATCH_SIZE)
    train_loader, test_loader = loaders

    all_logs: Dict[int, List[Dict[str, float]]] = {}
    for latent_dim in LATENT_DIMS:
        print(f"\n--- Training VAE with N_z = {latent_dim} ---")
        _model, run_logs = train_one_setting(latent_dim, train_loader, test_loader)
        all_logs[latent_dim] = run_logs

    plot_elbo_curves(all_logs)
    print("Done.")


if __name__ == "__main__":
    # Executed as a script (not imported): run the whole experiment.
    main()