Architecture & Design#

Overview#

PyAutoencoder provides clean, well-documented implementations of the following fundamental autoencoder architectures:

  1. Autoencoder (AE) – Vanilla encoder-decoder pair

  2. Variational Autoencoder (VAE) – Probabilistic encoder with sampling

  3. Adaptive Group VAE (AdaGVAE) – Paired-input VAE for disentangled representations

All follow a consistent build/forward/compute_loss pattern and integrate naturally with PyTorch. AdaGVAE extends the VAE interface for paired inputs.

Vanilla Autoencoder (AE)#

The AE is a deterministic encoder-decoder model:

Architecture

  • Encoder – Maps input \(x\) to latent code \(z = f(x)\)

  • Decoder – Maps latent \(z\) back to reconstruction \(\hat{x} = g(z)\)

  • Forward pass – Returns both \(z\) and \(\hat{x}\)

Training

from pyautoencoder.vanilla import AE

model = AE(encoder=encoder_nn, decoder=decoder_nn)
model.build(sample_input)

# Training step
output = model(x_batch)          # AEOutput with z and x_hat
loss_result = model.compute_loss(x_batch, output, likelihood='bernoulli')
loss_result.objective.backward()

Inference

# Extract latent codes
z = model.encode(x_batch).z              # AEEncodeOutput → latent tensor

# Reconstruct from latents
x_reconstructed = model.decode(z).x_hat  # AEDecodeOutput → reconstruction tensor

Output Structure

class AEOutput:
    x_hat: torch.Tensor       # Reconstruction [B, ...]
    z: torch.Tensor           # Latent code [B, D_z]

Variational Autoencoder (VAE)#

The VAE implements the VAE framework (Kingma & Welling, 2013) with probabilistic inference:

Architecture

  • Encoder – Maps input \(x\) to latent distribution \(q(z|x)\)

  • Sampling layer – Produces mean \(\mu\), log-variance \(\log\sigma^2\)

  • Decoder – Maps sampled \(z\) to reconstruction distribution \(p(x|z)\)

Training

from pyautoencoder.variational import VAE

model = VAE(encoder=encoder_nn, decoder=decoder_nn, latent_dim=64)
model.build(sample_input)

# Training step (sample multiple times for Monte Carlo estimates)
output = model(x_batch, S=5)  # 5 samples
loss_result = model.compute_loss(x_batch, output, beta=1.0, likelihood='gaussian')
loss_result.objective.backward()

Training vs Evaluation

  • Training mode (model.train()): - Samples \(S\) latent codes per input - Enables Monte Carlo averaging of reconstruction loss - Returns shape \([B, S, D_z]\) for latents

  • Evaluation mode (model.eval()): - Deterministic output (uses means, no sampling) - Faster inference - Still computes \(\mu\) and \(\log\sigma^2\) for diagnostics

Output Structure

class VAEOutput:
    x_hat: torch.Tensor       # Reconstructions [B, S, ...]
    z: torch.Tensor           # Samples [B, S, D_z]
    mu: torch.Tensor          # Posterior mean [B, D_z]
    log_var: torch.Tensor     # Posterior log-variance [B, D_z]

Sampling New Data

Generate samples from the prior \(p(z) = \mathcal{N}(0, I)\):

with torch.no_grad():
    z_prior = torch.randn(n_samples, latent_dim)
    x_samples = model.decoder(z_prior)

Adaptive Group Variational Autoencoder (AdaGVAE)#

The AdaGVAE implements the Ada-GVAE framework (Locatello et al., 2020). It wraps a VAE and adds an adaptive posterior-alignment step during training to encourage disentanglement.

Architecture

  • Backbone – A fully configured VAE (encoder, sampling layer, decoder are reused)

  • Paired encoding – Both inputs \(x_1\) and \(x_2\) are encoded independently to obtain \(q_1(z|x_1)\) and \(q_2(z|x_2)\)

  • Adaptive alignment – Dimensions where the per-dimension KL divergence \(\mathrm{KL}(q_1 \| q_2)\) falls below a threshold \(\tau\) are shared (posterior averaged); the rest are kept independent

  • Decoder – The standard VAE decoder is applied to samples from the adapted posteriors

The threshold is computed per sample as the midpoint of the min and max per-dimension KL values:

\[\tau = \tfrac{1}{2}(\max_d \mathrm{KL}_d + \min_d \mathrm{KL}_d)\]

Training

from pyautoencoder.variational import VAE, AdaGVAE

vae = VAE(encoder=encoder_nn, decoder=decoder_nn, latent_dim=64)
model = AdaGVAE(vae=vae)
model.build(sample_input)

# Training step — forward takes a pair
output = model((x1_batch, x2_batch), S=5)
loss_result = model.compute_loss((x1_batch, x2_batch), output, beta=4.0, likelihood='bernoulli')
loss_result.objective.backward()

Inference

After training, AdaGVAE reuses the underlying VAE for single-image inference:

# Encode / decode through the wrapped VAE as usual
z = model.vae.encode(x).z
x_reconstructed = model.vae.decode(z).x_hat

Output Structure

class AdaGVAEOutput:
    output1: VAEOutput    # Adapted output for x1
    output2: VAEOutput    # Adapted output for x2

Each VAEOutput contains x_hat, z, mu, and log_var for the corresponding adapted posterior.

Loss Functions#

The loss computation is integrated into each model via the compute_loss() method. This approach keeps loss logic close to the model implementation and ensures consistency.

Reconstruction Loss (AE)#

For standard autoencoders, use AE.compute_loss() to compute reconstruction loss:

output = model(x)
loss_result = model.compute_loss(x, output, likelihood='gaussian')
loss_result.objective.backward()

Supported likelihoods:

  • Gaussian – Continuous data

    \[\text{NLL} = \frac{1}{2}(x-\hat{x})^2\]
  • Bernoulli – Discrete/binary data (logits)

    \[\text{NLL} = \text{BCE}_{\text{logits}}(x, \hat{x})\]

ELBO Loss (VAE)#

For variational autoencoders, use VAE.compute_loss() to compute the negative ELBO:

output = model(x, S=5)
loss_result = model.compute_loss(x, output, beta=1.0, likelihood='gaussian')
loss_result.objective.backward()

The ELBO decomposes into reconstruction and regularization terms:

\[\mathcal{L}(x; \beta) = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \beta \, \underbrace{\mathrm{KL}(q(z|x) \| p(z))}_{\text{Regularization}}\]

Beta-VAE Weighting

The \(\beta\) hyperparameter controls the KL weight:

  • \(\beta = 1.0\) – Standard VAE (matches true ELBO)

  • \(\beta > 1.0\) – Stronger regularization (more disentangled latents)

  • \(\beta < 1.0\) – Weaker regularization (better reconstruction)

Paired ELBO Loss (AdaGVAE)#

For AdaGVAE, use AdaGVAE.compute_loss() to compute the combined objective over the pair:

output = model((x1, x2), S=5)
loss_result = model.compute_loss((x1, x2), output, beta=4.0, likelihood='bernoulli')
loss_result.objective.backward()

The loss is the sum of two independent beta-VAE ELBOs, each evaluated on the adapted posteriors \(q(\hat{z}|x_1)\) and \(q(\hat{z}|x_2)\):

\[\mathcal{L}(x_1, x_2; \beta) = \Bigl[\mathbb{E}_{q(\hat{z}|x_1)}[\log p(x_1|\hat{z})] - \beta\,\mathrm{KL}(q(\hat{z}|x_1)\|p(\hat{z}))\Bigr] + \Bigl[\mathbb{E}_{q(\hat{z}|x_2)}[\log p(x_2|\hat{z})] - \beta\,\mathrm{KL}(q(\hat{z}|x_2)\|p(\hat{z}))\Bigr]\]

The diagnostics dictionary exposes per-input reconstruction and KL terms (log_likelihood_x1, log_likelihood_x2, kl_divergence_x1, kl_divergence_x2) so both views can be monitored independently.

Loss Result Structure

AE.compute_loss(), VAE.compute_loss(), and AdaGVAE.compute_loss() all return a LossResult with:

  • objective – scalar loss to optimize (backward-differentiable)

  • diagnostics – dictionary of scalar metrics (float values)

Example:

loss_result = model.compute_loss(x, output)
loss_result.objective.backward()              # Optimize this
for name, val in loss_result.diagnostics.items():
    log(name, val)                            # Monitor these

Key Design Principles#

User-Provided Networks

You provide the encoder and decoder as arbitrary PyTorch modules:

encoder = my_custom_encoder_network()  # Any nn.Module
decoder = my_custom_decoder_network()  # Any nn.Module
model = AE(encoder, decoder)

This keeps the library flexible and composable with existing PyTorch code.

Explicit Initialization

The build() method initializes size-dependent parameters:

model.build(sample_input)  # Required once before training

This catches shape mismatches early and makes model behavior transparent.

Clear Loss Integration

Loss functions are methods on the models themselves, providing a clean API:

output = model(x_batch)
loss_result = model.compute_loss(x_batch, output)
loss_result.objective.backward()              # Optimize this
for name, val in loss_result.diagnostics.items():
    log(name, val)                            # Monitor these

Consistent Interfaces

All the architectures follow the same pattern:

model.build(sample)                          # Initialize
output = model(x_batch)                      # Forward pass
loss_result = model.compute_loss(x, output)  # Compute loss
loss_result.objective.backward()             # Backprop