MNIST Vanilla Autoencoder Example

This example demonstrates how to train a fully connected autoencoder on the MNIST dataset using pyautoencoder.

Configuration and Utilities

We begin by importing the required libraries, defining the main hyperparameters, and setting up deterministic random seeds to ensure reproducibility.

 1from __future__ import annotations
 2
 3from pathlib import Path
 4import time
 5import random
 6
 7import numpy as np
 8import torch
 9import torch.nn as nn
10from torch.utils.data import DataLoader
11from torchvision import datasets, transforms
12import matplotlib.pyplot as plt
13
14from pyautoencoder.vanilla import AE
15
16
17# ---------------- Config ---------------- #
18LATENT_DIM = 128
19NUM_EPOCHS = 30
20BATCH_SIZE = 128
21LEARNING_RATE = 1e-3
22NUM_RECON_COLS = 10
23
24DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25SEED = 1926
26
27# ---------------- Utils ----------------- #
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to every generator.
    """
    # CPU-side generators first, in the order: stdlib, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Also seed the CUDA generator of the current device when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

Key elements in this section include:

  • LATENT_DIM, NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE – the primary model and training hyperparameters,

  • DEVICE – automatically selects "cuda" when available,

  • set_seed() – helper function to provide reproducible runs.

Data Loading

We prepare the MNIST training and test datasets using torchvision.datasets.MNIST. A simple ToTensor transform is used to map images to the \([0, 1]\) range.

def make_dataloaders(batch_size: int) -> tuple[DataLoader, DataLoader]:
    """Build the MNIST train and test data loaders.

    Both splits are stored under ``./data`` (downloaded if missing) and use a
    plain ``ToTensor`` transform, so images arrive as tensors in [0, 1].

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_dataloader, test_dataloader)``; only the training loader
        shuffles its samples.
    """
    to_tensor = transforms.ToTensor()  # maps to [0,1]

    # Train split first, then test split.
    train_set, test_set = (
        datasets.MNIST("./data", train=is_train, download=True, transform=to_tensor)
        for is_train in (True, False)
    )

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )

The make_dataloaders() function returns two torch.utils.data.DataLoader objects that supply batches to the training and evaluation loops.

Model Definition

The autoencoder consists of a compact fully connected encoder and decoder. The encoder flattens the input and maps it into a latent representation of dimension LATENT_DIM; the decoder reconstructs the image from this latent vector.

 1def make_autoencoder(latent_dim: int) -> AE:
 2    encoder = nn.Sequential(
 3        nn.Flatten(),
 4        nn.Linear(28 * 28, 256),
 5        nn.ReLU(),
 6        nn.Linear(256, latent_dim),
 7    )
 8    decoder = nn.Sequential(
 9        nn.Linear(latent_dim, 256),
10        nn.ReLU(),
11        nn.Linear(256, 28 * 28),
12        nn.Unflatten(-1, (1, 28, 28)),  # keep last layer linear; AELoss(bernoulli) expects logits
13    )
14    model = AE(encoder=encoder, decoder=decoder)
15    model.build(input_sample=torch.randn(1, 1, 28, 28))
16    return model

Notes:

  • The final decoder layer is linear and produces logits.

  • The model is explicitly built through AE.build(), which infers required shapes from a representative sample.

  • Reconstruction loss is computed via the AE.compute_loss() method, which uses a Bernoulli likelihood (interpreting the decoder output as logits).

Training Loop

Training uses standard mini-batch gradient descent with the torch.optim.Adam optimizer. During each iteration, we compute the reconstruction log-likelihood and related diagnostics.

 1def train(
 2    model: AE,
 3    train_loader: DataLoader,
 4    num_epochs: int,
 5    lr: float = LEARNING_RATE,
 6) -> None:
 7    model.to(DEVICE)
 8    model.train()
 9    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
10
11    start_time = time.time()
12    for epoch in range(1, num_epochs + 1):
13        running_log_likelihood = 0.0
14        n = 0
15
16        for x, _ in train_loader:
17            x = x.to(DEVICE)
18            optimizer.zero_grad()
19
20            out = model(x)
21            loss_info = model.compute_loss(x, out, likelihood='bernoulli')
22            loss_info.objective.backward()
23            optimizer.step()
24
25            batch_size = x.size(0)
26            running_log_likelihood += loss_info.diagnostics['log_likelihood'] * batch_size
27            n += batch_size
28
29        avg_NLL = running_log_likelihood / n
30        elapsed = time.time() - start_time
31
32        print(
33            f"Epoch {epoch:3d}/{num_epochs}  "
34            f"LogLikelihood={avg_NLL:.4f} | "
35            f"(elapsed {elapsed:.1f}s)"
36        )

For each batch:

  • The model produces latent codes and reconstructions: out = model(x).

  • AE.compute_loss() returns a LossResult containing:

    • objective – batch-mean negative log-likelihood (NLL) in nats,

    • diagnostics['log_likelihood'] – batch-mean log-likelihood (negative of objective).

  • The loss is backpropagated through the entire model.

  • The batch log-likelihood is weighted by the batch size and accumulated over the epoch; the resulting per-sample average is printed at the end of each epoch.

Visualizing Reconstructions

After training completes, we visualize a few images chosen at random from the first test batch alongside their reconstructions. The decoder outputs logits, so we apply a sigmoid to convert them to pixel intensities suitable for display.

 1@torch.no_grad()
 2def plot_test_reconstructions(
 3    model: AE,
 4    test_loader: DataLoader,
 5    latent_dim: int,
 6    num_cols: int = NUM_RECON_COLS,
 7    fname: str | None = None,
 8) -> None:
 9    model.eval()
10
11    x_batch, _ = next(iter(test_loader))
12    idx = torch.randperm(x_batch.size(0))[:num_cols]
13    x = x_batch[idx].to(DEVICE)
14
15    out = model(x)
16    x_hat = torch.sigmoid(out.x_hat)
17
18    x = x.cpu().numpy()
19    x_hat = x_hat.cpu().numpy()
20
21    fig, axes = plt.subplots(2, num_cols, figsize=(2.5 * num_cols, 4.5))
22    if num_cols == 1:
23        axes = np.array([axes])
24
25    axes[0, 0].set_ylabel("Original", fontsize=18)
26    axes[1, 0].set_ylabel("Reconstruction", fontsize=18)
27
28    for r in range(num_cols):
29        axes[0, r].imshow(x[r, 0], cmap="gray", vmin=0, vmax=1)
30        axes[0, r].set_xticks([])
31        axes[0, r].set_yticks([])
32        axes[1, r].imshow(x_hat[r, 0], cmap="gray", vmin=0, vmax=1)
33        axes[1, r].set_xticks([])
34        axes[1, r].set_yticks([])
35
36    plt.tight_layout()
37    out_path = fname or f"ae_mnist_test_recon_latent{latent_dim}.png"
38    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
39    plt.savefig(out_path, dpi=200, bbox_inches="tight")
40    plt.close(fig)
41    print(f"Saved figure -> {out_path}")

Putting It All Together

The main() function orchestrates the full example: seeding, dataset preparation, model creation, training, and saving a figure with test-set reconstructions.

 1def main() -> None:
 2    print(f"\n=== Training AE (latent_dim={LATENT_DIM}) for {NUM_EPOCHS} epochs ===")
 3    print(f"Using device: {DEVICE}")
 4
 5    set_seed(SEED)
 6    train_loader, test_loader = make_dataloaders(BATCH_SIZE)
 7
 8    model = make_autoencoder(LATENT_DIM)
 9    train(model, train_loader, NUM_EPOCHS)
10    plot_test_reconstructions(
11        model,
12        test_loader,
13        latent_dim=LATENT_DIM,
14        num_cols=NUM_RECON_COLS,
15        fname=f"ae_mnist_test_recon_latent{LATENT_DIM}.png",
16    )
17    print("Done.")

The script can also be executed directly:

1if __name__ == "__main__":
2    main()

Full Example Script

For completeness, the entire example script is shown below:

mnist_ae.py
  1from __future__ import annotations
  2
  3from pathlib import Path
  4import time
  5import random
  6
  7import numpy as np
  8import torch
  9import torch.nn as nn
 10from torch.utils.data import DataLoader
 11from torchvision import datasets, transforms
 12import matplotlib.pyplot as plt
 13
 14from pyautoencoder.vanilla import AE
 15
 16
 17# ---------------- Config ---------------- #
 18LATENT_DIM = 128
 19NUM_EPOCHS = 30
 20BATCH_SIZE = 128
 21LEARNING_RATE = 1e-3
 22NUM_RECON_COLS = 10
 23
 24DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 25SEED = 1926
 26
 27# ---------------- Utils ----------------- #
 28def set_seed(seed: int) -> None:
 29    random.seed(seed)
 30    np.random.seed(seed)
 31    torch.manual_seed(seed)
 32    if torch.cuda.is_available():
 33        torch.cuda.manual_seed(seed)
 34
 35
 36# ---------------- Data ------------------ #
 37def make_dataloaders(batch_size: int) -> tuple[DataLoader, DataLoader]:
 38    tfm = transforms.ToTensor()  # maps to [0,1]
 39    train_dataset = datasets.MNIST("./data", train=True,  download=True, transform=tfm)
 40    test_dataset  = datasets.MNIST("./data", train=False, download=True, transform=tfm)
 41
 42    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 43    test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)
 44    return train_dataloader, test_dataloader
 45
 46
 47# ---------------- Model ----------------- #
 48def make_autoencoder(latent_dim: int) -> AE:
 49    encoder = nn.Sequential(
 50        nn.Flatten(),
 51        nn.Linear(28 * 28, 256),
 52        nn.ReLU(),
 53        nn.Linear(256, latent_dim),
 54    )
 55    decoder = nn.Sequential(
 56        nn.Linear(latent_dim, 256),
 57        nn.ReLU(),
 58        nn.Linear(256, 28 * 28),
 59        nn.Unflatten(-1, (1, 28, 28)),  # keep last layer linear; AELoss(bernoulli) expects logits
 60    )
 61    model = AE(encoder=encoder, decoder=decoder)
 62    model.build(input_sample=torch.randn(1, 1, 28, 28))
 63    return model
 64
 65# ---------------- Train ----------------- #
 66def train(
 67    model: AE,
 68    train_loader: DataLoader,
 69    num_epochs: int,
 70    lr: float = LEARNING_RATE,
 71) -> None:
 72    model.to(DEVICE)
 73    model.train()
 74    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 75
 76    start_time = time.time()
 77    for epoch in range(1, num_epochs + 1):
 78        running_log_likelihood = 0.0
 79        n = 0
 80
 81        for x, _ in train_loader:
 82            x = x.to(DEVICE)
 83            optimizer.zero_grad()
 84
 85            out = model(x)
 86            loss_info = model.compute_loss(x, out, likelihood='bernoulli')
 87            loss_info.objective.backward()
 88            optimizer.step()
 89
 90            batch_size = x.size(0)
 91            running_log_likelihood += loss_info.diagnostics['log_likelihood'] * batch_size
 92            n += batch_size
 93
 94        avg_NLL = running_log_likelihood / n
 95        elapsed = time.time() - start_time
 96
 97        print(
 98            f"Epoch {epoch:3d}/{num_epochs}  "
 99            f"LogLikelihood={avg_NLL:.4f} | "
100            f"(elapsed {elapsed:.1f}s)"
101        )
102
103
104# ---------------- Plot (TEST samples) ---------------- #
@torch.no_grad()
def plot_test_reconstructions(
    model: AE,
    test_loader: DataLoader,
    latent_dim: int,
    num_cols: int = NUM_RECON_COLS,
    fname: str | None = None,
) -> None:
    """Save a figure comparing test images with their reconstructions.

    Images are drawn at random from the first batch of ``test_loader``. The
    top row shows the originals, the bottom row the sigmoid of the decoder
    output.

    Args:
        model: Trained autoencoder; switched to eval mode.
            NOTE(review): assumed to already live on ``DEVICE`` - confirm.
        test_loader: Loader whose first batch supplies the images.
        latent_dim: Only used to build the default output file name.
        num_cols: Number of image pairs to show; clamped to the size of the
            first batch.
        fname: Output path. Defaults to
            ``ae_mnist_test_recon_latent{latent_dim}.png``.
    """
    model.eval()

    x_batch, _ = next(iter(test_loader))
    # Never request more columns than the batch can fill; otherwise the
    # plotting loop below would index past the selected samples.
    num_cols = min(num_cols, x_batch.size(0))
    idx = torch.randperm(x_batch.size(0))[:num_cols]
    x = x_batch[idx].to(DEVICE)

    out = model(x)
    x_hat = torch.sigmoid(out.x_hat)  # decoder emits logits -> intensities

    x = x.cpu().numpy()
    x_hat = x_hat.cpu().numpy()

    # squeeze=False always returns a (2, num_cols) array of axes, so the
    # 2-D indexing below also works for a single column.
    fig, axes = plt.subplots(
        2, num_cols, figsize=(2.5 * num_cols, 4.5), squeeze=False
    )

    axes[0, 0].set_ylabel("Original", fontsize=18)
    axes[1, 0].set_ylabel("Reconstruction", fontsize=18)

    for col in range(num_cols):
        for row, images in enumerate((x, x_hat)):
            ax = axes[row, col]
            ax.imshow(images[col, 0], cmap="gray", vmin=0, vmax=1)
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    out_path = fname or f"ae_mnist_test_recon_latent{latent_dim}.png"
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure -> {out_path}")
146
147
148# ---------------- Main ------------------ #
def main() -> None:
    """Run the whole example: seed, load data, train, and save reconstructions."""
    banner = f"\n=== Training AE (latent_dim={LATENT_DIM}) for {NUM_EPOCHS} epochs ==="
    print(banner)
    print(f"Using device: {DEVICE}")

    # Seed before any data loader or model is created.
    set_seed(SEED)
    loaders = make_dataloaders(BATCH_SIZE)
    train_loader, test_loader = loaders

    autoencoder = make_autoencoder(LATENT_DIM)
    train(autoencoder, train_loader, NUM_EPOCHS)

    figure_name = f"ae_mnist_test_recon_latent{LATENT_DIM}.png"
    plot_test_reconstructions(
        autoencoder,
        test_loader,
        latent_dim=LATENT_DIM,
        num_cols=NUM_RECON_COLS,
        fname=figure_name,
    )
    print("Done.")
166
167
168if __name__ == "__main__":
169    main()