Architecture & Design
Overview
PyAutoencoder provides clean, well-documented implementations of the following fundamental autoencoder architectures:
Autoencoder (AE) – Vanilla encoder-decoder pair
Variational Autoencoder (VAE) – Probabilistic encoder with sampling
Adaptive Group VAE (AdaGVAE) – Paired-input VAE for disentangled representations
All follow a consistent build/forward/compute_loss pattern and integrate naturally with PyTorch. AdaGVAE extends the VAE interface for paired inputs.
Vanilla Autoencoder (AE)
The AE is a deterministic encoder-decoder model:
Architecture
Encoder – Maps input \(x\) to latent code \(z = f(x)\)
Decoder – Maps latent \(z\) back to reconstruction \(\hat{x} = g(z)\)
Forward pass – Returns both \(z\) and \(\hat{x}\)
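To make the two mappings concrete, here is a minimal plain-PyTorch sketch of a deterministic encoder-decoder pair. It does not use PyAutoencoder: the class name TinyAE, the layer sizes, and the tuple return value are illustrative assumptions (the library itself wraps the results in an AEOutput).

```python
import torch
from torch import nn


class TinyAE(nn.Module):
    """Hypothetical deterministic autoencoder: z = f(x), x_hat = g(z)."""

    def __init__(self, in_dim: int = 784, latent_dim: int = 32) -> None:
        super().__init__()
        self.f = nn.Sequential(nn.Linear(in_dim, 128), nn.ReLU(), nn.Linear(128, latent_dim))
        self.g = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, in_dim))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.f(x)      # latent code, shape [B, D_z]
        x_hat = self.g(z)  # reconstruction, same shape as x
        return z, x_hat


model = TinyAE()
z, x_hat = model(torch.rand(8, 784))
print(z.shape, x_hat.shape)  # torch.Size([8, 32]) torch.Size([8, 784])
```

Because no sampling is involved, the same input always maps to the same z and x̂.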
Training
from pyautoencoder.vanilla import AE
model = AE(encoder=encoder_nn, decoder=decoder_nn)
model.build(sample_input)
# Training step
output = model(x_batch) # AEOutput with z and x_hat
loss_result = model.compute_loss(x_batch, output, likelihood='bernoulli')
loss_result.objective.backward()
Inference
# Extract latent codes
z = model.encode(x_batch).z # AEEncodeOutput → latent tensor
# Reconstruct from latents
x_reconstructed = model.decode(z).x_hat # AEDecodeOutput → reconstruction tensor
Output Structure
class AEOutput:
x_hat: torch.Tensor # Reconstruction [B, ...]
z: torch.Tensor # Latent code [B, D_z]
Variational Autoencoder (VAE)
The VAE class implements the variational autoencoder (Kingma & Welling, 2013), replacing the deterministic latent code with a learned posterior distribution:
Architecture
Encoder – Maps input \(x\) to latent distribution \(q(z|x)\)
Sampling layer – Produces the mean \(\mu\) and log-variance \(\log\sigma^2\) of \(q(z|x)\) and draws latent samples \(z\) from it
Decoder – Maps sampled \(z\) to reconstruction distribution \(p(x|z)\)
Training
from pyautoencoder.variational import VAE
model = VAE(encoder=encoder_nn, decoder=decoder_nn, latent_dim=64)
model.build(sample_input)
# Training step (sample multiple times for Monte Carlo estimates)
output = model(x_batch, S=5) # 5 samples
loss_result = model.compute_loss(x_batch, output, beta=1.0, likelihood='gaussian')
loss_result.objective.backward()
Training vs Evaluation
Training mode (model.train()):
- Samples \(S\) latent codes per input
- Enables Monte Carlo averaging of the reconstruction loss
- Returns latents of shape \([B, S, D_z]\)
Evaluation mode (model.eval()):
- Deterministic output (uses the posterior means, no sampling)
- Faster inference
- Still computes \(\mu\) and \(\log\sigma^2\) for diagnostics
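The training/evaluation difference can be sketched in plain PyTorch with the standard reparameterization trick, \(z = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\). This is an illustrative sketch, not the library's sampling layer: sample_latents is a hypothetical helper, and returning the mean with a singleton sample axis ([B, 1, D_z]) in evaluation mode is an assumed convention.

```python
import torch


def sample_latents(mu: torch.Tensor, log_var: torch.Tensor, S: int, training: bool) -> torch.Tensor:
    """Hypothetical helper: draw S reparameterized samples per input.

    mu, log_var: [B, D_z] posterior parameters.
    Returns [B, S, D_z] in training mode, [B, 1, D_z] (the mean) otherwise.
    """
    if not training:
        return mu.unsqueeze(1)  # deterministic: use the posterior mean, no sampling
    std = torch.exp(0.5 * log_var)  # sigma = exp(log_var / 2)
    eps = torch.randn(mu.size(0), S, mu.size(1), dtype=mu.dtype, device=mu.device)
    # z = mu + sigma * eps keeps z differentiable w.r.t. mu and log_var
    return mu.unsqueeze(1) + std.unsqueeze(1) * eps


mu = torch.zeros(4, 64, requires_grad=True)
log_var = torch.zeros(4, 64, requires_grad=True)
z_train = sample_latents(mu, log_var, S=5, training=True)
z_eval = sample_latents(mu, log_var, S=5, training=False)
print(z_train.shape, z_eval.shape)  # torch.Size([4, 5, 64]) torch.Size([4, 1, 64])
```

Because the noise enters additively, gradients flow back to mu and log_var, which is what allows the encoder to be trained through Monte Carlo samples.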
Output Structure
class VAEOutput:
x_hat: torch.Tensor # Reconstructions [B, S, ...]
z: torch.Tensor # Samples [B, S, D_z]
mu: torch.Tensor # Posterior mean [B, D_z]
log_var: torch.Tensor # Posterior log-variance [B, D_z]
Sampling New Data
Generate samples from the prior \(p(z) = \mathcal{N}(0, I)\):
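import torch
n_samples, latent_dim = 16, 64  # latent_dim must match VAE(..., latent_dim=64)
model.eval()
with torch.no_grad():
    z_prior = torch.randn(n_samples, latent_dim, device=next(model.parameters()).device)  # z ~ N(0, I)
    x_samples = model.decoder(z_prior)  # logits under likelihood='bernoulli'; apply torch.sigmoid for probabilities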
Adaptive Group Variational Autoencoder (AdaGVAE)
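The AdaGVAE class implements the Ada-GVAE method (Locatello et al., 2020) for weakly supervised disentanglement, in which each training example is a pair of inputs that share an unknown subset of underlying factors of variation. It wraps a VAE and adds an adaptive posterior-alignment step during training to encourage disentanglement.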
Architecture
Backbone – A fully configured VAE (its encoder, sampling layer, and decoder are reused)
Paired encoding – Both inputs \(x_1\) and \(x_2\) are encoded independently to obtain \(q_1(z|x_1)\) and \(q_2(z|x_2)\)
Adaptive alignment – Dimensions where the per-dimension KL divergence \(\mathrm{KL}(q_1 \| q_2)\) falls below a threshold \(\tau\) are shared (posterior averaged); the rest are kept independent
Decoder – The standard VAE decoder is applied to samples from the adapted posteriors
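The threshold is computed per sample as the midpoint of the minimum and maximum per-dimension KL values:
\[\delta_j = \mathrm{KL}\big(q_1(z_j|x_1)\,\|\,q_2(z_j|x_2)\big), \qquad \tau = \frac{1}{2}\Big(\min_j \delta_j + \max_j \delta_j\Big)\]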
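The alignment step can be sketched for diagonal Gaussian posteriors as below. The function names are hypothetical, and on shared dimensions this sketch averages both the means and the variances of the two posteriors; PyAutoencoder's exact averaging rule is an assumption here.

```python
import torch


def kl_diag_gauss(mu1, log_var1, mu2, log_var2):
    """Per-dimension KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians."""
    var1, var2 = log_var1.exp(), log_var2.exp()
    return 0.5 * (log_var2 - log_var1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)


def adapt_posteriors(mu1, log_var1, mu2, log_var2):
    """Hypothetical sketch of the Ada-GVAE alignment step on [B, D_z] posteriors."""
    delta = kl_diag_gauss(mu1, log_var1, mu2, log_var2)       # [B, D_z]
    tau = 0.5 * (delta.min(dim=1, keepdim=True).values
                 + delta.max(dim=1, keepdim=True).values)     # [B, 1]: one threshold per sample
    shared = delta < tau                                      # dimensions whose KL falls below tau
    # Averaged posterior (mean of means, mean of variances) used on shared dimensions
    mu_avg = 0.5 * (mu1 + mu2)
    log_var_avg = torch.log(0.5 * (log_var1.exp() + log_var2.exp()))
    out1 = (torch.where(shared, mu_avg, mu1), torch.where(shared, log_var_avg, log_var1))
    out2 = (torch.where(shared, mu_avg, mu2), torch.where(shared, log_var_avg, log_var2))
    return out1, out2, shared


mu1 = torch.tensor([[0.0, 0.0, 0.0]])
mu2 = torch.tensor([[0.0, 0.1, 3.0]])
lv = torch.zeros(1, 3)
(m1, _), (m2, _), shared = adapt_posteriors(mu1, lv, mu2, lv)
print(shared.tolist())  # [[True, True, False]]
```

Here the per-dimension KLs are 0, 0.005, and 4.5, so \(\tau = 2.25\): the first two dimensions are averaged and the third, where the two inputs clearly differ, stays independent.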
Training
from pyautoencoder.variational import VAE, AdaGVAE
vae = VAE(encoder=encoder_nn, decoder=decoder_nn, latent_dim=64)
model = AdaGVAE(vae=vae)
model.build(sample_input)
# Training step — forward takes a pair
output = model((x1_batch, x2_batch), S=5)
loss_result = model.compute_loss((x1_batch, x2_batch), output, beta=4.0, likelihood='bernoulli')
loss_result.objective.backward()
Inference
After training, AdaGVAE reuses the underlying VAE for single-input inference:
# Encode / decode through the wrapped VAE as usual
z = model.vae.encode(x).z
x_reconstructed = model.vae.decode(z).x_hat
Output Structure
class AdaGVAEOutput:
output1: VAEOutput # Adapted output for x1
output2: VAEOutput # Adapted output for x2
Each VAEOutput contains x_hat, z, mu, and log_var for the corresponding adapted posterior.
Loss Functions
The loss computation is integrated into each model via the compute_loss() method.
Keeping the loss next to the model means each objective is computed directly from the output structure that the model's own forward pass returns.
Reconstruction Loss (AE)
For standard autoencoders, use AE.compute_loss() to compute the reconstruction loss:
output = model(x)
loss_result = model.compute_loss(x, output, likelihood='gaussian')
loss_result.objective.backward()
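Supported likelihoods:
Gaussian – Continuous data; unit-variance Gaussian, shown up to the additive constant \(\tfrac{1}{2}\log 2\pi\) per element
\[\text{NLL} = \frac{1}{2}(x-\hat{x})^2\]
Bernoulli – Binary or \([0, 1]\)-valued data; the decoder output \(\hat{x}\) is interpreted as logits
\[\text{NLL} = \text{BCE}_{\text{logits}}(x, \hat{x})\]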
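Both negative log-likelihoods can be written with standard PyTorch operations. In this sketch the per-element terms are summed over all non-batch dimensions; that reduction and the helper names gaussian_nll and bernoulli_nll are assumptions rather than the library's implementation.

```python
import math

import torch
import torch.nn.functional as F


def gaussian_nll(x, x_hat, include_const=False):
    """Unit-variance Gaussian NLL per sample, summed over non-batch dims."""
    nll = 0.5 * (x - x_hat) ** 2
    if include_const:
        nll = nll + 0.5 * math.log(2 * math.pi)  # constant term, irrelevant for gradients
    return nll.flatten(1).sum(dim=1)


def bernoulli_nll(x, logits):
    """Bernoulli NLL per sample; the decoder output is interpreted as logits."""
    bce = F.binary_cross_entropy_with_logits(logits, x, reduction="none")
    return bce.flatten(1).sum(dim=1)


x = torch.tensor([[0.0, 1.0]])
print(gaussian_nll(x, torch.tensor([[0.0, 0.0]])))  # tensor([0.5000])
print(bernoulli_nll(x, torch.zeros(1, 2)))           # tensor([1.3863]), i.e. 2 * log 2
```

Using binary_cross_entropy_with_logits rather than applying a sigmoid first avoids numerical problems when the logits are large in magnitude.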
ELBO Loss (VAE)
For variational autoencoders, use VAE.compute_loss() to compute the negative ELBO:
output = model(x, S=5)
loss_result = model.compute_loss(x, output, beta=1.0, likelihood='gaussian')
loss_result.objective.backward()
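The ELBO decomposes into a reconstruction term and a KL regularization term, with the reconstruction expectation estimated by Monte Carlo over the \(S\) latent samples \(z^{(s)} \sim q(z|x)\):
\[\text{ELBO}(x) = \mathbb{E}_{q(z|x)}\big[\log p(x|z)\big] - \mathrm{KL}\big(q(z|x)\,\|\,p(z)\big) \approx \frac{1}{S}\sum_{s=1}^{S}\log p\big(x|z^{(s)}\big) - \mathrm{KL}\big(q(z|x)\,\|\,p(z)\big)\]
The objective returned by compute_loss() is the negative of this quantity with the KL term scaled by \(\beta\), i.e. \(-\big(\mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta\,\mathrm{KL}(q(z|x)\,\|\,p(z))\big)\).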
Beta-VAE Weighting
The \(\beta\) hyperparameter controls the KL weight:
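\(\beta = 1.0\) – Standard VAE (the objective is exactly the negative ELBO)
\(\beta > 1.0\) – Stronger regularization (tends to encourage more disentangled latents, usually at some cost in reconstruction quality)
\(\beta < 1.0\) – Weaker regularization (typically better reconstruction, with a less regular latent space)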
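Putting the pieces together, here is a sketch of a β-weighted negative ELBO for a diagonal Gaussian posterior and a standard normal prior. It assumes a Bernoulli likelihood, the closed-form Gaussian KL, decoder outputs with an explicit sample axis [B, S, ...], and averaging over the batch; the function names are hypothetical and not the library's API.

```python
import torch
import torch.nn.functional as F


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)) per sample: [B, D_z] -> [B]."""
    return 0.5 * (mu ** 2 + log_var.exp() - 1.0 - log_var).sum(dim=1)


def negative_elbo(x, logits, mu, log_var, beta=1.0):
    """Hypothetical beta-weighted negative ELBO with a Bernoulli likelihood.

    x: [B, ...] data; logits: [B, S, ...] decoder outputs for S latent samples.
    """
    x_rep = x.unsqueeze(1).expand_as(logits)  # align x with the S samples
    log_px = -F.binary_cross_entropy_with_logits(logits, x_rep, reduction="none")
    recon = log_px.flatten(2).sum(dim=2).mean(dim=1)  # Monte Carlo average over S -> [B]
    kl = kl_to_standard_normal(mu, log_var)           # [B]
    return (-(recon - beta * kl)).mean()              # average over the batch


B, S, D, Dz = 2, 5, 4, 3
x = torch.ones(B, D)
logits = torch.zeros(B, S, D)
mu, log_var = torch.zeros(B, Dz), torch.zeros(B, Dz)
print(negative_elbo(x, logits, mu, log_var, beta=4.0))  # tensor(2.7726): 4 * log 2, KL is 0 here
```

With the posterior equal to the prior the KL term vanishes, so β has no effect in this call; with a nonzero posterior mean, raising β adds proportionally more KL penalty to the objective.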
Paired ELBO Loss (AdaGVAE)
For AdaGVAE, use AdaGVAE.compute_loss() to compute the combined objective over the pair:
output = model((x1, x2), S=5)
loss_result = model.compute_loss((x1, x2), output, beta=4.0, likelihood='bernoulli')
loss_result.objective.backward()
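The loss is the sum of two β-weighted negative ELBOs, one per input, each evaluated on its adapted posterior \(q(\hat{z}|x_1)\) or \(q(\hat{z}|x_2)\):
\[\mathcal{L} = -\Big(\mathbb{E}_{q(\hat{z}|x_1)}\big[\log p(x_1|\hat{z})\big] - \beta\,\mathrm{KL}\big(q(\hat{z}|x_1)\,\|\,p(\hat{z})\big)\Big) - \Big(\mathbb{E}_{q(\hat{z}|x_2)}\big[\log p(x_2|\hat{z})\big] - \beta\,\mathrm{KL}\big(q(\hat{z}|x_2)\,\|\,p(\hat{z})\big)\Big)\]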
The diagnostics dictionary exposes per-input reconstruction and KL terms (log_likelihood_x1, log_likelihood_x2, kl_divergence_x1, kl_divergence_x2) so both views can be monitored independently.
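A sketch of the paired objective, applying the same per-input terms to both views. The function names, the (logits, mu, log_var) tuple layout, the Bernoulli likelihood, and batch averaging are assumptions; only the diagnostic key names come from the description above.

```python
import torch
import torch.nn.functional as F


def elbo_terms(x, logits, mu, log_var):
    """Per-input (log-likelihood, KL), both averaged over the batch.

    x: [B, ...]; logits: [B, S, ...]; mu, log_var: [B, D_z] (adapted posterior).
    """
    x_rep = x.unsqueeze(1).expand_as(logits)
    log_px = -F.binary_cross_entropy_with_logits(logits, x_rep, reduction="none")
    log_lik = log_px.flatten(2).sum(dim=2).mean(dim=1).mean()
    kl = (0.5 * (mu ** 2 + log_var.exp() - 1.0 - log_var).sum(dim=1)).mean()
    return log_lik, kl


def paired_loss(x1, x2, out1, out2, beta=4.0):
    """Hypothetical paired objective: sum of two beta-weighted negative ELBOs."""
    ll1, kl1 = elbo_terms(x1, *out1)
    ll2, kl2 = elbo_terms(x2, *out2)
    objective = -(ll1 - beta * kl1) - (ll2 - beta * kl2)
    diagnostics = {
        "log_likelihood_x1": ll1.item(),
        "log_likelihood_x2": ll2.item(),
        "kl_divergence_x1": kl1.item(),
        "kl_divergence_x2": kl2.item(),
    }
    return objective, diagnostics


B, S, D, Dz = 2, 3, 4, 2
x = torch.ones(B, D)
out = (torch.zeros(B, S, D), torch.zeros(B, Dz), torch.zeros(B, Dz))
objective, diag = paired_loss(x, x, out, out)
print(round(objective.item(), 4))  # 5.5452, i.e. 2 * (4 * log 2)
```

Reporting the four terms separately makes it easy to spot when one view reconstructs well while the other does not.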
Loss Result Structure
AE.compute_loss(), VAE.compute_loss(), and AdaGVAE.compute_loss() all return a LossResult with:
objective – scalar loss tensor to optimize (call .backward() on it)
diagnostics – dictionary of scalar metrics (float values)
Example:
loss_result = model.compute_loss(x, output)
loss_result.objective.backward() # Optimize this
for name, val in loss_result.diagnostics.items():
    print(f"{name}: {val:.4f}") # Monitor these
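When monitoring training, it is often convenient to average the diagnostics over an epoch. The sketch below uses a hypothetical LossResultSketch dataclass that mirrors LossResult's two documented fields; the diagnostic key name is illustrative.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class LossResultSketch:
    """Hypothetical stand-in mirroring LossResult's two documented fields."""
    objective: torch.Tensor                              # scalar, supports .backward()
    diagnostics: dict[str, float] = field(default_factory=dict)


def average_diagnostics(results):
    """Average each diagnostic over per-batch results (assumes every batch reports the same keys)."""
    totals: dict[str, float] = {}
    for r in results:
        for name, val in r.diagnostics.items():
            totals[name] = totals.get(name, 0.0) + val
    return {name: total / len(results) for name, total in totals.items()}


results = [
    LossResultSketch(torch.tensor(2.0), {"kl_divergence": 1.0}),
    LossResultSketch(torch.tensor(4.0), {"kl_divergence": 3.0}),
]
print(average_diagnostics(results))  # {'kl_divergence': 2.0}
```

Since the diagnostics are plain floats, accumulating them this way does not keep any autograd graph alive between batches.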
Key Design Principles
User-Provided Networks
You provide the encoder and decoder as arbitrary PyTorch modules:
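import torch.nn as nn
from pyautoencoder.vanilla import AE
encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 32))  # Any nn.Module (sizes here assume 1x28x28 inputs)
decoder = nn.Sequential(nn.Linear(32, 784), nn.Unflatten(1, (1, 28, 28)))  # Any nn.Module
model = AE(encoder=encoder, decoder=decoder)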
This keeps the library flexible and composable with existing PyTorch code.
Explicit Initialization
The build() method initializes size-dependent parameters:
model.build(sample_input) # Required once before training
This catches shape mismatches early, before training starts, and makes it explicit when the size-dependent parameters are created.
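The internals of build() are not shown here. As an analogue, PyTorch's own lazy modules (nn.LazyLinear) also infer layer sizes from a sample input, and the sketch below uses one to show a practical consequence: create the optimizer only after the size-dependent parameters exist. If build() registers new parameters, as "initializes size-dependent parameters" suggests, the same ordering applies to model.build(sample_input).

```python
import torch
from torch import nn

# LazyLinear infers in_features from the first input it sees,
# similar in spirit to initializing size-dependent parameters in build().
head = nn.Sequential(nn.LazyLinear(16), nn.ReLU(), nn.Linear(16, 4))

sample_input = torch.rand(2, 784)
with torch.no_grad():
    head(sample_input)  # dry run materializes the lazy weights

print(head[0].weight.shape)  # torch.Size([16, 784])

# Create the optimizer after the dry run so it holds the materialized parameters.
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
```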
Clear Loss Integration
Loss functions are methods on the models themselves, providing a clean API:
output = model(x_batch)
loss_result = model.compute_loss(x_batch, output)
loss_result.objective.backward() # Optimize this
for name, val in loss_result.diagnostics.items():
    print(f"{name}: {val:.4f}") # Monitor these
Consistent Interfaces
All the architectures follow the same pattern:
model.build(sample) # Initialize
output = model(x_batch) # Forward pass (AdaGVAE takes a pair: model((x1, x2)))
loss_result = model.compute_loss(x_batch, output) # Compute loss
loss_result.objective.backward() # Backprop
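Because every model shares this pattern, one training loop can serve all of them. A minimal sketch, assuming only the calls shown above: train_epoch and StubAE are hypothetical, and StubAE merely imitates the interface so the loop can run without PyAutoencoder. For AdaGVAE, each element of batches would be a pair (x1, x2), which the loop passes through unchanged.

```python
from types import SimpleNamespace

import torch
from torch import nn


def train_epoch(model, batches, optimizer, forward_kwargs=None, loss_kwargs=None):
    """Run one epoch using only build/forward/compute_loss; returns per-batch objectives."""
    forward_kwargs = forward_kwargs or {}  # e.g. {"S": 5} for VAE / AdaGVAE
    loss_kwargs = loss_kwargs or {}        # e.g. {"beta": 4.0, "likelihood": "bernoulli"}
    model.train()
    history = []
    for x in batches:
        optimizer.zero_grad()
        output = model(x, **forward_kwargs)
        loss_result = model.compute_loss(x, output, **loss_kwargs)
        loss_result.objective.backward()
        optimizer.step()
        history.append(loss_result.objective.item())
    return history


class StubAE(nn.Module):
    """Hypothetical stand-in that imitates the interface; not part of PyAutoencoder."""

    def __init__(self, latent_dim=8):
        super().__init__()
        self.latent_dim = latent_dim

    def build(self, sample):
        # Size-dependent layers are created from the sample's feature dimension
        d = sample.flatten(1).size(1)
        self.encoder = nn.Linear(d, self.latent_dim)
        self.decoder = nn.Linear(self.latent_dim, d)

    def forward(self, x):
        z = self.encoder(x)
        return SimpleNamespace(z=z, x_hat=self.decoder(z))

    def compute_loss(self, x, output):
        # Unit-variance Gaussian NLL (constant dropped), averaged over the batch
        objective = 0.5 * ((x - output.x_hat) ** 2).sum(dim=1).mean()
        return SimpleNamespace(objective=objective, diagnostics={"nll": objective.item()})


torch.manual_seed(0)
batches = list(torch.rand(64, 20).split(16))
model = StubAE()
model.build(batches[0])                                    # build first...
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # ...then create the optimizer
losses = [train_epoch(model, batches, optimizer)[-1] for _ in range(20)]
```

The last-batch objective recorded in losses should fall over the epochs as the stub learns to reconstruct its inputs.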