Autoencoder Models

Vanilla Autoencoder

class pyautoencoder.vanilla.AE(encoder: Module, decoder: Module)

Bases: BaseAutoencoder

Vanilla Autoencoder composed of a user-defined encoder and decoder.

The model follows the BaseAutoencoder interface and implements:

  • _encode(x) – maps inputs x to latent codes z.

  • _decode(z) – maps latent codes z to reconstructions x_hat.

  • forward(x) – full training forward pass returning both z and x_hat.

The encoder and decoder are arbitrary torch.nn.Module instances that define the mapping between data space and latent space.

__init__(encoder: Module, decoder: Module)

Construct an Autoencoder from an encoder and decoder module.

Parameters:
  • encoder (nn.Module) – Module implementing the mapping x → z.

  • decoder (nn.Module) – Module implementing the mapping z → x_hat.

compute_loss(x: Tensor, ae_output: AEOutput, likelihood: str | LikelihoodType = LikelihoodType.GAUSSIAN) → LossResult

Compute Autoencoder reconstruction loss.

The scalar loss is the batch-mean reconstruction negative log-likelihood (NLL). The method also computes diagnostics to monitor model behavior.

Parameters:
  • x (torch.Tensor) – Ground-truth inputs of shape [B, ...].

  • ae_output (AEOutput) –

    Output from the AE forward pass. Expected fields include:

    • x_hat (torch.Tensor): Reconstructions, shape [B, ...].

    • z (torch.Tensor): Latent representation (unused by this method).

  • likelihood (str | LikelihoodType, optional) – Likelihood model for computing the reconstruction term ('gaussian' or 'bernoulli'). Defaults to Gaussian.

Returns:

Result containing:

  • objective – Scalar batch-mean reconstruction NLL (in nats).

  • diagnostics – Dictionary with:

    • "log_likelihood": Negative of the objective (batch-mean log-likelihood).

Return type:

LossResult

Notes

Reductions follow:

  1. Elementwise log-likelihood.

  2. Sum over feature dimensions.

  3. Mean over the batch.

Ensure that inputs match the chosen likelihood:

  • Gaussian: continuous data (typically standardized).

  • Bernoulli: targets in \([0, 1]\), predictions given as logits.
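The reduction order in the notes above can be traced with a small plain-Python sketch. It is illustrative only and does not reproduce the library's implementation: the helper names are hypothetical, nested lists stand in for tensors of shape [B, D], and the Gaussian case assumes unit variance, which is not stated in this reference.

```python
import math

def gaussian_log_prob(x, x_hat):
    # Elementwise log N(x; x_hat, 1). Unit variance is an assumption of this sketch.
    return -0.5 * ((x - x_hat) ** 2 + math.log(2.0 * math.pi))

def bernoulli_log_prob(x, logit):
    # Elementwise Bernoulli log-likelihood from a logit:
    # x * logit - softplus(logit), with softplus computed in a stable form.
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return x * logit - softplus

def batch_mean_nll(x, x_hat, elementwise):
    # 1. elementwise log-likelihood, 2. sum over features, 3. mean over the batch.
    per_sample = [
        sum(elementwise(a, b) for a, b in zip(row_x, row_hat))
        for row_x, row_hat in zip(x, x_hat)
    ]
    log_likelihood = sum(per_sample) / len(per_sample)
    return -log_likelihood

x = [[0.0, 1.0], [1.0, 0.0]]  # B = 2, D = 2
gaussian_nll = batch_mean_nll(x, x, gaussian_log_prob)  # perfect reconstruction
bernoulli_nll = batch_mean_nll(x, [[0.0, 0.0], [0.0, 0.0]], bernoulli_log_prob)  # all logits 0
```

Because features are summed rather than averaged, the loss scales with the number of features, which is worth keeping in mind when comparing values across input sizes.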

forward(x: Tensor) → AEOutput

Full training forward pass with gradients.

Parameters:

x (torch.Tensor) – Input batch of shape [B, ...].

Returns:

Dataclass containing both the reconstruction x_hat and the latent code z.

Return type:

AEOutput

Data Structures

class pyautoencoder.vanilla.AEEncodeOutput(z: Tensor)

Output of the Autoencoder encoder stage.

z

Latent code of shape [B, ...] produced by AE._encode() or AE.encode().

Type:

torch.Tensor

class pyautoencoder.vanilla.AEDecodeOutput(x_hat: Tensor)

Output of the Autoencoder decoder stage.

x_hat

Reconstruction (or logits) of shape [B, ...] produced by AE._decode() or AE.decode().

Type:

torch.Tensor

class pyautoencoder.vanilla.AEOutput(x_hat: Tensor, z: Tensor)

Output of the full Autoencoder forward pass.

x_hat

Reconstruction (or logits) of shape [B, ...] produced by AE.forward().

Type:

torch.Tensor

z

Latent code of shape [B, ...] produced by AE.forward().

Type:

torch.Tensor

Variational Autoencoder

class pyautoencoder.variational.VAE(encoder: Module, decoder: Module, latent_dim: int)

Bases: BaseAutoencoder

Variational Autoencoder following Kingma & Welling (2013).

The model consists of:

  • an encoder mapping x → f(x) (feature representation),

  • a fully factorized Gaussian head producing (z, mu, log_var),

  • a decoder mapping latent samples z → x_hat.

In training mode, latent samples z are drawn with the reparameterization trick; in evaluation mode, the posterior mean is repeated in place of random samples, so the output is deterministic.
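The difference between the two modes can be sketched in plain Python. This is a minimal illustration, not the library's sampling layer: sample_latents and its arguments are hypothetical names, and lists stand in for a single input's [D_z] parameters.

```python
import math
import random

def sample_latents(mu, log_var, S, training, rng):
    # mu, log_var: posterior parameters for one input, each of length D_z.
    # Returns S latent vectors, i.e. a nested list of shape [S, D_z].
    if not training:
        # Evaluation: deterministic, the mean repeated S times.
        return [list(mu) for _ in range(S)]
    samples = []
    for _ in range(S):
        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, 1),
        # where sigma = exp(0.5 * log_var).
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
        samples.append(z)
    return samples

rng = random.Random(0)
mu, log_var = [0.5, -1.0], [0.0, -2.0]
z_train = sample_latents(mu, log_var, S=3, training=True, rng=rng)
z_eval = sample_latents(mu, log_var, S=3, training=False, rng=rng)
```

Writing the sample as a deterministic function of mu, log_var, and external noise is what allows gradients to reach the encoder parameters.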

__init__(encoder: Module, decoder: Module, latent_dim: int)

Construct a Variational Autoencoder from an encoder, decoder, and latent size.

Parameters:
  • encoder (nn.Module) – Maps input x to a feature vector f(x) with shape [B, F].

  • decoder (nn.Module) – Maps latent samples z to reconstructions x_hat.

  • latent_dim (int) – Dimensionality D_z of the latent space.

Notes

A FullyFactorizedGaussian sampling layer is created internally and not exposed as a constructor parameter.

compute_loss(x: Tensor, vae_output: VAEOutput, beta: float = 1, likelihood: str | LikelihoodType = LikelihoodType.GAUSSIAN) → LossResult

Compute the Evidence Lower Bound (ELBO) for a (beta-)Variational Autoencoder.

This method implements the beta-VAE objective:

\[\mathcal{L}(x; \beta) = \mathbb{E}_{q(z \mid x)}[\log p(x \mid z)] \;-\; \beta \, \mathrm{KL}(q(z \mid x) \,\|\, p(z)).\]

The reconstruction term \(\log p(x \mid z)\) is computed using loss.base.log_likelihood(), which supports both Gaussian and Bernoulli likelihoods.
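As a rough illustration of how the two terms combine, the sketch below evaluates the KL term in closed form and assembles the quantity to minimize. It assumes a standard normal prior p(z) = N(0, I), which this reference does not state explicitly, and the function names are hypothetical rather than part of the library.

```python
import math

def kl_to_standard_normal(mu, log_var):
    # KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dimensions.
    # The standard normal prior is an assumption of this sketch.
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))

def negative_elbo(log_likelihood, kl, beta=1.0):
    # The ELBO is maximized, so the value handed to an optimizer is its negative.
    elbo = log_likelihood - beta * kl
    return -elbo

kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])     # posterior equals the prior
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])  # one mean moved by 1
loss = negative_elbo(log_likelihood=-3.0, kl=kl_shifted, beta=4.0)
```

With beta above 1 the KL penalty weighs more heavily against the reconstruction term.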

Monte Carlo estimation

If x_hat in vae_output contains S Monte Carlo samples, the expectation \(\mathbb{E}_{q(z \mid x)}\) is approximated by:

\[\mathbb{E}_{q(z \mid x)}[\log p(x \mid z)] \approx \frac{1}{S} \sum_{s=1}^{S} \log p(x \mid z^{(s)}).\]

Broadcasting

  • If x_hat has shape [B, ...], it is expanded to [B, 1, ...].

  • x is broadcast to match the sample dimension of x_hat.
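The Monte Carlo average and the reuse of x across samples can be sketched with nested lists in place of tensors. The helper names are hypothetical, and the unit-variance Gaussian likelihood is an assumption made only to keep the example concrete.

```python
import math

def log_prob_gaussian(x, x_hat):
    # Sum over features of log N(x; x_hat, 1). Unit variance is assumed here.
    return sum(-0.5 * ((a - b) ** 2 + math.log(2.0 * math.pi)) for a, b in zip(x, x_hat))

def mc_log_likelihood(x, x_hat):
    # x: [B, D]; x_hat: [B, S, D]. Each x[b] is reused for all S samples
    # (the role broadcasting plays for tensors), then averaged over S.
    per_input = []
    for x_b, samples_b in zip(x, x_hat):
        per_sample = [log_prob_gaussian(x_b, x_hat_s) for x_hat_s in samples_b]
        per_input.append(sum(per_sample) / len(per_sample))
    return sum(per_input) / len(per_input)  # mean over the batch

x = [[0.0, 0.0]]                    # B = 1, D = 2
x_hat = [[[0.0, 0.0], [1.0, 1.0]]]  # B = 1, S = 2, D = 2
estimate = mc_log_likelihood(x, x_hat)
```

With S = 1 the average over samples reduces to the single-sample case, which corresponds to expanding [B, ...] to [B, 1, ...].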

Parameters:
  • x (torch.Tensor) – Ground-truth inputs of shape [B, ...].

  • vae_output (VAEOutput) –

    Output from the VAE forward pass. Expected fields include:

    • x_hat (torch.Tensor): Reconstructed samples, shape [B, ...] or [B, S, ...].

    • mu (torch.Tensor): Mean of \(q(z \mid x)\), shape [B, D_z].

    • log_var (torch.Tensor): Log-variance of \(q(z \mid x)\), shape [B, D_z].

  • beta (float, optional) – Weighting factor for the KL term (beta-VAE). beta = 1 yields the standard VAE. Defaults to 1.

  • likelihood (str | LikelihoodType, optional) – Likelihood model for the reconstruction term ('gaussian' or 'bernoulli'). Defaults to Gaussian.

Returns:

Result containing:

  • objective – Negative mean ELBO (scalar).

  • diagnostics – Dictionary with:

    • "elbo": Mean ELBO over the batch.

    • "log_likelihood": Mean reconstruction term \(\mathbb{E}_{q}[\log p(x \mid z)]\).

    • "kl_divergence": Mean \(\mathrm{KL}(q \,\|\, p)\) over the batch.

Return type:

LossResult

Notes

  • All returned diagnostics are batch means.

  • Gradients flow through the decoder; neither input is detached.

forward(x: Tensor, S: int = 1) → VAEOutput

Full VAE pass: encode, sample S times, decode.

Parameters:
  • x (torch.Tensor) – Input batch of shape [B, ...].

  • S (int, optional) – Number of latent samples for Monte Carlo estimates. Defaults to 1.

Returns:

Contains reconstructions x_hat, latent samples z, and the posterior parameters mu and log_var.

Return type:

VAEOutput

Notes

If S > 1, loss computation can broadcast x to shape [B, S, ...] without materializing copies. For Bernoulli likelihoods, the decoder must output logits.

Data Structures

class pyautoencoder.variational.VAEEncodeOutput(z: Tensor, mu: Tensor, log_var: Tensor)

Output of the VAE encoder stage.

z

Latent samples of shape [B, S, D_z], produced by VAE._encode() or VAE.encode().

Type:

torch.Tensor

mu

Mean of the approximate posterior \(q(z \mid x)\), shape [B, D_z].

Type:

torch.Tensor

log_var

Log-variance of \(q(z \mid x)\), shape [B, D_z].

Type:

torch.Tensor

class pyautoencoder.variational.VAEDecodeOutput(x_hat: Tensor)

Output of the VAE decoder stage.

x_hat

Reconstructions or logits of shape [B, S, ...], produced by VAE._decode() or VAE.decode().

Type:

torch.Tensor

class pyautoencoder.variational.VAEOutput(x_hat: Tensor, z: Tensor, mu: Tensor, log_var: Tensor)

Output of a full VAE forward pass.

x_hat

Reconstructions or logits of shape [B, S, ...], produced by VAE.forward().

Type:

torch.Tensor

z

Latent samples of shape [B, S, D_z], produced by VAE.forward().

Type:

torch.Tensor

mu

Mean of \(q(z \mid x)\), shape [B, D_z].

Type:

torch.Tensor

log_var

Log-variance of \(q(z \mid x)\), shape [B, D_z].

Type:

torch.Tensor

Adaptive Group Variational Autoencoder

class pyautoencoder.variational.AdaGVAE(vae: VAE)

Bases: Module

Adaptive Group Variational Autoencoder (Ada-GVAE), from Locatello et al. (2020).

Wraps a VAE and adds adaptive posterior grouping for feature disentanglement. All VAE parameters are tracked through this wrapper.

forward() expects a pair of inputs (x1, x2) and returns an AdaGVAEOutput with adapted latent representations for both. For single-image inference after training, use model.vae.encode and model.vae.decode directly.
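For orientation, the sketch below shows one way adaptive posterior grouping can be carried out for a single pair, based on the description in Locatello et al. (2020) rather than on this library's source. The KL direction, the threshold rule, the averaging rule, and all function names are assumptions of the sketch and may differ from what AdaGVAE does internally.

```python
import math

def kl_per_dim(mu1, lv1, mu2, lv2):
    # KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))), one value per latent dimension.
    return [
        0.5 * (b2 - b1) + (math.exp(b1) + (a1 - a2) ** 2) / (2.0 * math.exp(b2)) - 0.5
        for a1, b1, a2, b2 in zip(mu1, lv1, mu2, lv2)
    ]

def adapt_posteriors(mu1, lv1, mu2, lv2):
    delta = kl_per_dim(mu1, lv1, mu2, lv2)
    # Threshold halfway between the largest and smallest per-dimension KL.
    tau = 0.5 * (max(delta) + min(delta))
    shared = [d < tau for d in delta]
    mu1_hat, lv1_hat, mu2_hat, lv2_hat = list(mu1), list(lv1), list(mu2), list(lv2)
    for i, is_shared in enumerate(shared):
        if is_shared:
            # Shared dimension: both inputs receive the averaged mean and variance.
            mean = 0.5 * (mu1[i] + mu2[i])
            log_var = math.log(0.5 * (math.exp(lv1[i]) + math.exp(lv2[i])))
            mu1_hat[i] = mu2_hat[i] = mean
            lv1_hat[i] = lv2_hat[i] = log_var
    return (mu1_hat, lv1_hat), (mu2_hat, lv2_hat), shared

# Dimension 0 is nearly identical across the pair; dimension 1 differs strongly.
q1, q2, shared = adapt_posteriors([0.0, 2.0], [0.0, 0.0], [0.1, -2.0], [0.0, 0.0])
```

In this example only dimension 0 falls below the threshold, so it is averaged while dimension 1 keeps its separate posteriors.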

__init__(vae: VAE)

Wrap a VAE with adaptive posterior grouping.

Parameters:

vae (VAE) – A configured VAE instance whose encoder, decoder, and sampling layer are reused for the paired training objective.

compute_loss(x: tuple[Tensor, Tensor], vae_output: AdaGVAEOutput, beta: float = 1, likelihood: str | LikelihoodType = LikelihoodType.GAUSSIAN) → LossResult

Compute the combined ELBO for a pair of inputs with adapted posteriors.

\[\mathcal{L}(x_1, x_2; \beta) = \left[ \mathbb{E}_{q(\hat{z} \mid x_1)}[\log p(x_1 \mid \hat{z})] \;-\; \beta \, \mathrm{KL}(q(\hat{z} \mid x_1) \,\|\, p(\hat{z})) \right] + \left[ \mathbb{E}_{q(\hat{z} \mid x_2)}[\log p(x_2 \mid \hat{z})] \;-\; \beta \, \mathrm{KL}(q(\hat{z} \mid x_2) \,\|\, p(\hat{z})) \right].\]

Parameters:
  • x (tuple[torch.Tensor, torch.Tensor]) – The (x1, x2) pair of ground-truth inputs, each of shape [B, ...].

  • vae_output (AdaGVAEOutput) – Output from forward() called in training mode.

  • beta (float, optional) – KL weighting factor. beta = 1 yields the standard objective. Defaults to 1.

  • likelihood (str | LikelihoodType, optional) – Likelihood model for the reconstruction term ('gaussian' or 'bernoulli'). Defaults to Gaussian.

Returns:

Result containing:

  • objective – Sum of negative ELBOs for both inputs (scalar).

  • diagnostics – Dictionary with:

    • "elbo": Sum of mean ELBOs for both inputs.

    • "log_likelihood_x1": Mean reconstruction term for x1.

    • "log_likelihood_x2": Mean reconstruction term for x2.

    • "kl_divergence_x1": Mean KL divergence for x1.

    • "kl_divergence_x2": Mean KL divergence for x2.

Return type:

LossResult
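How the per-input terms might combine into the returned values can be sketched as follows. pair_loss is a hypothetical helper that takes already-computed batch-mean terms; only the dictionary keys come from the description above.

```python
def pair_loss(ll_x1, kl_x1, ll_x2, kl_x2, beta=1.0):
    # Each argument is a batch-mean scalar for one input of the pair.
    elbo_x1 = ll_x1 - beta * kl_x1
    elbo_x2 = ll_x2 - beta * kl_x2
    elbo = elbo_x1 + elbo_x2
    diagnostics = {
        "elbo": elbo,
        "log_likelihood_x1": ll_x1,
        "log_likelihood_x2": ll_x2,
        "kl_divergence_x1": kl_x1,
        "kl_divergence_x2": kl_x2,
    }
    # The objective to minimize is the sum of the two negative ELBOs.
    return -elbo, diagnostics

objective, diagnostics = pair_loss(-10.0, 2.0, -12.0, 3.0, beta=2.0)
```

Note that the two ELBOs are summed rather than averaged, so the objective is roughly twice the magnitude of a single-input VAE loss on comparable data.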

forward(x: tuple[Tensor, Tensor], S: int = 1) → AdaGVAEOutput

AdaGVAE training pass on a pair of images.

For single-image inference after training use model.vae.encode and model.vae.decode.

Parameters:
  • x (tuple[torch.Tensor, torch.Tensor]) – A (x1, x2) pair, each of shape [B, ...].

  • S (int, optional) – Number of latent samples for Monte Carlo estimates. Defaults to 1.

Returns:

Adapted pair outputs containing reconstructions and posterior parameters for both inputs.

Return type:

AdaGVAEOutput