MNIST Vanilla Autoencoder Example#
This example demonstrates how to train a fully connected autoencoder on
the MNIST dataset using pyautoencoder.
Configuration and Utilities#
We begin by importing the required libraries, defining the main hyperparameters, and setting up deterministic random seeds to ensure reproducibility.
1from __future__ import annotations
2
3from pathlib import Path
4import time
5import random
6
7import numpy as np
8import torch
9import torch.nn as nn
10from torch.utils.data import DataLoader
11from torchvision import datasets, transforms
12import matplotlib.pyplot as plt
13
14from pyautoencoder.vanilla import AE
15
16
# ---------------- Config ---------------- #
LATENT_DIM = 128       # size of the latent code produced by the encoder
NUM_EPOCHS = 30        # number of passes over the training set
BATCH_SIZE = 128       # samples per mini-batch
LEARNING_RATE = 1e-3   # default step size for the optimizer
NUM_RECON_COLS = 10    # number of images shown in the reconstruction figure

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SEED = 1926            # seed used to make runs reproducible
26
27# ---------------- Utils ----------------- #
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Seed the CUDA generator explicitly as well when a GPU is usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
Key elements in this section include:
- LATENT_DIM, NUM_EPOCHS, BATCH_SIZE – the primary model and training hyperparameters,
- DEVICE – automatically selects "cuda" when available,
- set_seed() – helper function to provide reproducible runs.
Data Loading#
We prepare the MNIST training and test datasets using
torchvision.datasets.MNIST. A simple ToTensor transform is
used to map images to the \([0, 1]\) range.
def make_dataloaders(batch_size: int) -> tuple[DataLoader, DataLoader]:
    """Create the MNIST training and test data loaders.

    Both datasets are stored under ``./data`` (requested with
    ``download=True``) and use a ``ToTensor`` transform. Only the
    training loader shuffles its samples.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    to_tensor = transforms.ToTensor()  # maps to [0,1]

    def _mnist_split(is_train: bool):
        # Same root, download flag and transform for both splits.
        return datasets.MNIST("./data", train=is_train, download=True, transform=to_tensor)

    # Build both datasets before wrapping them, train split first.
    train_split = _mnist_split(True)
    test_split = _mnist_split(False)

    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )
The make_dataloaders() function returns two
torch.utils.data.DataLoader objects that supply batches to the
training and evaluation loops.
Model Definition#
The autoencoder consists of a compact fully connected encoder and
decoder. The encoder flattens the input and maps it into a latent
representation of dimension LATENT_DIM; the decoder reconstructs
the image from this latent vector.
def make_autoencoder(latent_dim: int) -> AE:
    """Build a fully connected autoencoder for 1x28x28 inputs.

    The encoder flattens the input and maps it through one hidden layer
    to ``latent_dim`` features; the decoder mirrors it and reshapes the
    output back to ``(1, 28, 28)``.

    Args:
        latent_dim: Width of the latent representation.

    Returns:
        An ``AE`` instance on which ``build`` has already been called.
    """
    n_pixels = 28 * 28
    n_hidden = 256

    encoder_layers = [
        nn.Flatten(),
        nn.Linear(n_pixels, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, latent_dim),
    ]
    # keep last layer linear; AELoss(bernoulli) expects logits
    decoder_layers = [
        nn.Linear(latent_dim, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, n_pixels),
        nn.Unflatten(-1, (1, 28, 28)),
    ]

    autoencoder = AE(
        encoder=nn.Sequential(*encoder_layers),
        decoder=nn.Sequential(*decoder_layers),
    )
    # A single random sample with the input shape is handed to build().
    sample = torch.randn(1, 1, 28, 28)
    autoencoder.build(input_sample=sample)
    return autoencoder
Notes:
The final decoder layer is linear and produces logits.
The model is explicitly built through AE.build(), which infers required shapes from a representative sample.
Reconstruction loss is computed via the AE.compute_loss() method, which uses a Bernoulli likelihood (interpreting the decoder output as logits).
Training Loop#
Training uses standard mini-batch gradient descent with the
torch.optim.Adam optimizer. During each iteration, we compute
the reconstruction log-likelihood and related diagnostics.
def train(
    model: AE,
    train_loader: DataLoader,
    num_epochs: int,
    lr: float = LEARNING_RATE,
) -> None:
    """Train ``model`` with Adam on the Bernoulli reconstruction objective.

    Moves the model to ``DEVICE``, switches it to training mode and runs
    ``num_epochs`` passes over ``train_loader``. After every epoch the
    sample-weighted mean of the per-batch log-likelihood is printed,
    together with the time elapsed since training started.

    Args:
        model: Autoencoder to optimise in place.
        train_loader: Loader yielding ``(inputs, labels)`` pairs; the
            labels are ignored.
        num_epochs: Number of full passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.
    """
    model.to(DEVICE)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # perf_counter is monotonic, so the reported elapsed time cannot jump
    # if the system clock is adjusted while training runs.
    start_time = time.perf_counter()
    for epoch in range(1, num_epochs + 1):
        running_log_likelihood = 0.0
        n = 0

        for x, _ in train_loader:
            x = x.to(DEVICE)
            optimizer.zero_grad()

            out = model(x)
            loss_info = model.compute_loss(x, out, likelihood='bernoulli')
            loss_info.objective.backward()
            optimizer.step()

            batch_size = x.size(0)
            # float() accepts a plain number or a 0-dim tensor, so the
            # running sum stays a Python float and holds no tensor alive.
            batch_log_likelihood = float(loss_info.diagnostics['log_likelihood'])
            # Weight by batch size so a smaller final batch counts fairly.
            running_log_likelihood += batch_log_likelihood * batch_size
            n += batch_size

        # An empty loader yields no samples; report NaN instead of
        # failing with a division by zero.
        avg_log_likelihood = running_log_likelihood / n if n else float("nan")
        elapsed = time.perf_counter() - start_time

        print(
            f"Epoch {epoch:3d}/{num_epochs} "
            f"LogLikelihood={avg_log_likelihood:.4f} | "
            f"(elapsed {elapsed:.1f}s)"
        )
For each batch:
The model produces latent codes and reconstructions: out = model(x).
AE.compute_loss() returns a LossResult containing: objective – batch-mean negative log-likelihood (NLL) in nats, and diagnostics['log_likelihood'] – batch-mean log-likelihood (the negative of objective).
The loss is backpropagated through the entire model.
These quantities are accumulated over the epoch and reported in the log.
Visualizing Reconstructions#
After training completes, we visualize a few images drawn at random from the first test batch alongside their reconstructions. The decoder outputs logits, so we apply a sigmoid to convert them to pixel intensities suitable for display.
@torch.no_grad()
def plot_test_reconstructions(
    model: AE,
    test_loader: DataLoader,
    latent_dim: int,
    num_cols: int = NUM_RECON_COLS,
    fname: str | None = None,
) -> None:
    """Save a figure comparing test images with their reconstructions.

    Takes the first batch from ``test_loader``, picks up to ``num_cols``
    images at random, runs them through ``model`` and writes a two-row
    figure (originals on top, reconstructions below) to disk.

    Args:
        model: Trained autoencoder; it is switched to eval mode.
        test_loader: Loader yielding ``(inputs, labels)`` pairs.
        latent_dim: Only used to build the default output file name.
        num_cols: Maximum number of images to show. Fewer are shown when
            the batch contains fewer samples.
        fname: Output path. Defaults to
            ``ae_mnist_test_recon_latent{latent_dim}.png``.
    """
    model.eval()

    x_batch, _ = next(iter(test_loader))
    idx = torch.randperm(x_batch.size(0))[:num_cols]
    x = x_batch[idx].to(DEVICE)
    # The batch may hold fewer than num_cols images; size the grid to
    # what was actually selected so the plotting loop stays in bounds.
    n_shown = min(num_cols, x.size(0))

    out = model(x)
    # The model output is treated as logits; sigmoid maps it to [0, 1].
    x_hat = torch.sigmoid(out.x_hat)

    x = x.cpu().numpy()
    x_hat = x_hat.cpu().numpy()

    # squeeze=False always returns a 2-D (2, n_shown) axes array, so the
    # [row, col] indexing below also works for a single column.
    fig, axes = plt.subplots(
        2, n_shown, figsize=(2.5 * n_shown, 4.5), squeeze=False
    )

    axes[0, 0].set_ylabel("Original", fontsize=18)
    axes[1, 0].set_ylabel("Reconstruction", fontsize=18)

    for col in range(n_shown):
        for row, images in enumerate((x, x_hat)):
            ax = axes[row, col]
            # Index 0 on the second axis selects the first channel.
            ax.imshow(images[col, 0], cmap="gray", vmin=0, vmax=1)
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    out_path = fname or f"ae_mnist_test_recon_latent{latent_dim}.png"
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure -> {out_path}")
Putting It All Together#
The main() function orchestrates the full example: seeding,
dataset preparation, model creation, training, and saving a figure with
test-set reconstructions.
def main() -> None:
    """Run the example end to end: seed, load data, build, train, plot."""
    print(f"\n=== Training AE (latent_dim={LATENT_DIM}) for {NUM_EPOCHS} epochs ===")
    print(f"Using device: {DEVICE}")

    # Seed before creating loaders and model so the run is repeatable.
    set_seed(SEED)
    loaders = make_dataloaders(BATCH_SIZE)
    train_loader, test_loader = loaders

    autoencoder = make_autoencoder(LATENT_DIM)
    train(autoencoder, train_loader, NUM_EPOCHS)

    figure_path = f"ae_mnist_test_recon_latent{LATENT_DIM}.png"
    plot_test_reconstructions(
        autoencoder,
        test_loader,
        latent_dim=LATENT_DIM,
        num_cols=NUM_RECON_COLS,
        fname=figure_path,
    )
    print("Done.")
The script can also be executed directly:
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Full Example Script#
For completeness, the entire example script is shown below:
1from __future__ import annotations
2
3from pathlib import Path
4import time
5import random
6
7import numpy as np
8import torch
9import torch.nn as nn
10from torch.utils.data import DataLoader
11from torchvision import datasets, transforms
12import matplotlib.pyplot as plt
13
14from pyautoencoder.vanilla import AE
15
16
# ---------------- Config ---------------- #
LATENT_DIM = 128       # size of the latent code produced by the encoder
NUM_EPOCHS = 30        # number of passes over the training set
BATCH_SIZE = 128       # samples per mini-batch
LEARNING_RATE = 1e-3   # default step size for the optimizer
NUM_RECON_COLS = 10    # number of images shown in the reconstruction figure

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SEED = 1926            # seed used to make runs reproducible
26
27# ---------------- Utils ----------------- #
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Seed the CUDA generator explicitly as well when a GPU is usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
34
35
36# ---------------- Data ------------------ #
def make_dataloaders(batch_size: int) -> tuple[DataLoader, DataLoader]:
    """Create the MNIST training and test data loaders.

    Both datasets are stored under ``./data`` (requested with
    ``download=True``) and use a ``ToTensor`` transform. Only the
    training loader shuffles its samples.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    to_tensor = transforms.ToTensor()  # maps to [0,1]

    def _mnist_split(is_train: bool):
        # Same root, download flag and transform for both splits.
        return datasets.MNIST("./data", train=is_train, download=True, transform=to_tensor)

    # Build both datasets before wrapping them, train split first.
    train_split = _mnist_split(True)
    test_split = _mnist_split(False)

    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )
45
46
47# ---------------- Model ----------------- #
def make_autoencoder(latent_dim: int) -> AE:
    """Build a fully connected autoencoder for 1x28x28 inputs.

    The encoder flattens the input and maps it through one hidden layer
    to ``latent_dim`` features; the decoder mirrors it and reshapes the
    output back to ``(1, 28, 28)``.

    Args:
        latent_dim: Width of the latent representation.

    Returns:
        An ``AE`` instance on which ``build`` has already been called.
    """
    n_pixels = 28 * 28
    n_hidden = 256

    encoder_layers = [
        nn.Flatten(),
        nn.Linear(n_pixels, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, latent_dim),
    ]
    # keep last layer linear; AELoss(bernoulli) expects logits
    decoder_layers = [
        nn.Linear(latent_dim, n_hidden),
        nn.ReLU(),
        nn.Linear(n_hidden, n_pixels),
        nn.Unflatten(-1, (1, 28, 28)),
    ]

    autoencoder = AE(
        encoder=nn.Sequential(*encoder_layers),
        decoder=nn.Sequential(*decoder_layers),
    )
    # A single random sample with the input shape is handed to build().
    sample = torch.randn(1, 1, 28, 28)
    autoencoder.build(input_sample=sample)
    return autoencoder
64
65# ---------------- Train ----------------- #
def train(
    model: AE,
    train_loader: DataLoader,
    num_epochs: int,
    lr: float = LEARNING_RATE,
) -> None:
    """Train ``model`` with Adam on the Bernoulli reconstruction objective.

    Moves the model to ``DEVICE``, switches it to training mode and runs
    ``num_epochs`` passes over ``train_loader``. After every epoch the
    sample-weighted mean of the per-batch log-likelihood is printed,
    together with the time elapsed since training started.

    Args:
        model: Autoencoder to optimise in place.
        train_loader: Loader yielding ``(inputs, labels)`` pairs; the
            labels are ignored.
        num_epochs: Number of full passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.
    """
    model.to(DEVICE)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # perf_counter is monotonic, so the reported elapsed time cannot jump
    # if the system clock is adjusted while training runs.
    start_time = time.perf_counter()
    for epoch in range(1, num_epochs + 1):
        running_log_likelihood = 0.0
        n = 0

        for x, _ in train_loader:
            x = x.to(DEVICE)
            optimizer.zero_grad()

            out = model(x)
            loss_info = model.compute_loss(x, out, likelihood='bernoulli')
            loss_info.objective.backward()
            optimizer.step()

            batch_size = x.size(0)
            # float() accepts a plain number or a 0-dim tensor, so the
            # running sum stays a Python float and holds no tensor alive.
            batch_log_likelihood = float(loss_info.diagnostics['log_likelihood'])
            # Weight by batch size so a smaller final batch counts fairly.
            running_log_likelihood += batch_log_likelihood * batch_size
            n += batch_size

        # An empty loader yields no samples; report NaN instead of
        # failing with a division by zero.
        avg_log_likelihood = running_log_likelihood / n if n else float("nan")
        elapsed = time.perf_counter() - start_time

        print(
            f"Epoch {epoch:3d}/{num_epochs} "
            f"LogLikelihood={avg_log_likelihood:.4f} | "
            f"(elapsed {elapsed:.1f}s)"
        )
102
103
104# ---------------- Plot (TEST samples) ---------------- #
@torch.no_grad()
def plot_test_reconstructions(
    model: AE,
    test_loader: DataLoader,
    latent_dim: int,
    num_cols: int = NUM_RECON_COLS,
    fname: str | None = None,
) -> None:
    """Save a figure comparing test images with their reconstructions.

    Takes the first batch from ``test_loader``, picks up to ``num_cols``
    images at random, runs them through ``model`` and writes a two-row
    figure (originals on top, reconstructions below) to disk.

    Args:
        model: Trained autoencoder; it is switched to eval mode.
        test_loader: Loader yielding ``(inputs, labels)`` pairs.
        latent_dim: Only used to build the default output file name.
        num_cols: Maximum number of images to show. Fewer are shown when
            the batch contains fewer samples.
        fname: Output path. Defaults to
            ``ae_mnist_test_recon_latent{latent_dim}.png``.
    """
    model.eval()

    x_batch, _ = next(iter(test_loader))
    idx = torch.randperm(x_batch.size(0))[:num_cols]
    x = x_batch[idx].to(DEVICE)
    # The batch may hold fewer than num_cols images; size the grid to
    # what was actually selected so the plotting loop stays in bounds.
    n_shown = min(num_cols, x.size(0))

    out = model(x)
    # The model output is treated as logits; sigmoid maps it to [0, 1].
    x_hat = torch.sigmoid(out.x_hat)

    x = x.cpu().numpy()
    x_hat = x_hat.cpu().numpy()

    # squeeze=False always returns a 2-D (2, n_shown) axes array, so the
    # [row, col] indexing below also works for a single column.
    fig, axes = plt.subplots(
        2, n_shown, figsize=(2.5 * n_shown, 4.5), squeeze=False
    )

    axes[0, 0].set_ylabel("Original", fontsize=18)
    axes[1, 0].set_ylabel("Reconstruction", fontsize=18)

    for col in range(n_shown):
        for row, images in enumerate((x, x_hat)):
            ax = axes[row, col]
            # Index 0 on the second axis selects the first channel.
            ax.imshow(images[col, 0], cmap="gray", vmin=0, vmax=1)
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    out_path = fname or f"ae_mnist_test_recon_latent{latent_dim}.png"
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure -> {out_path}")
146
147
148# ---------------- Main ------------------ #
def main() -> None:
    """Run the example end to end: seed, load data, build, train, plot."""
    print(f"\n=== Training AE (latent_dim={LATENT_DIM}) for {NUM_EPOCHS} epochs ===")
    print(f"Using device: {DEVICE}")

    # Seed before creating loaders and model so the run is repeatable.
    set_seed(SEED)
    loaders = make_dataloaders(BATCH_SIZE)
    train_loader, test_loader = loaders

    autoencoder = make_autoencoder(LATENT_DIM)
    train(autoencoder, train_loader, NUM_EPOCHS)

    figure_path = f"ae_mnist_test_recon_latent{LATENT_DIM}.png"
    plot_test_reconstructions(
        autoencoder,
        test_loader,
        latent_dim=LATENT_DIM,
        num_cols=NUM_RECON_COLS,
        fname=figure_path,
    )
    print("Done.")
166
167
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()