Loss Functions & Utilities

Core Loss Functions

pyautoencoder.loss.log_likelihood(x: Tensor, x_hat: Tensor, likelihood: str | LikelihoodType = LikelihoodType.GAUSSIAN) → Tensor

Compute the elementwise log-likelihood \(\log p(x \mid \hat{x})\).

Two likelihood models are supported.

  • Gaussian (continuous data): Assuming a fixed unit variance \(\sigma^2 = 1\), each element follows:

    \[\log p(x \mid \hat{x}) = -\tfrac{1}{2} (x - \hat{x})^2.\]

    The output has the same shape as x. Summing over feature dimensions gives per-sample log-likelihoods.

  • Bernoulli (binary data): x_hat is interpreted as logits. Each element follows:

    \[\log p(x \mid \hat{x}) = x \log \sigma(\hat{x}) + (1 - x) \log\!\left( 1 - \sigma(\hat{x}) \right),\]

    where \(\sigma\) is the sigmoid function. It is evaluated as the negative of torch.nn.functional.binary_cross_entropy_with_logits(), which is numerically stable.
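Both cases can be sketched in plain PyTorch as follows. This is a minimal illustration, not the library's source: the helper name `elementwise_log_likelihood` is hypothetical, and it accepts only the string form of `likelihood`. It also shows how summing over the feature dimensions yields one log-likelihood per sample.

```python
import torch
import torch.nn.functional as F


def elementwise_log_likelihood(
    x: torch.Tensor, x_hat: torch.Tensor, likelihood: str = "gaussian"
) -> torch.Tensor:
    """Hypothetical sketch of the two documented cases; not the library source."""
    if likelihood == "gaussian":
        # Unit-variance Gaussian with the -0.5 * log(2*pi) constant dropped
        return -0.5 * (x - x_hat) ** 2
    if likelihood == "bernoulli":
        # x_hat holds logits; the log-likelihood is the negated elementwise BCE
        return -F.binary_cross_entropy_with_logits(x_hat, x, reduction="none")
    raise ValueError(f"unknown likelihood: {likelihood!r}")


torch.manual_seed(0)
x = torch.rand(4, 3, 8, 8)       # [B, ...] targets in [0, 1]
x_hat = torch.randn(4, 3, 8, 8)  # reconstructions (logits for the Bernoulli case)

ll = elementwise_log_likelihood(x, x_hat, "bernoulli")  # same shape as x
per_sample = ll.flatten(start_dim=1).sum(dim=1)          # sum over features -> [B]
print(ll.shape, per_sample.shape)  # torch.Size([4, 3, 8, 8]) torch.Size([4])
```

Flattening before the sum keeps the per-sample reduction independent of how many feature dimensions x has.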

Parameters:
  • x (torch.Tensor) – Ground-truth tensor of shape [B, ...].

  • x_hat (torch.Tensor) – Reconstructed tensor of shape [B, ...]. For the Bernoulli case, values are logits.

  • likelihood (str | LikelihoodType, optional) – Likelihood model to use. May be a string ("gaussian", "bernoulli") or a LikelihoodType enum value. Defaults to Gaussian.

Returns:

Elementwise log-likelihood with the same shape as x.

Return type:

torch.Tensor

Notes

  • The Gaussian case omits the normalization constant \(-\tfrac{1}{2}\log(2\pi)\), which is constant with respect to the model parameters and has no effect on optimization.

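  • The Bernoulli case is numerically stable because it works with the logits directly in log-space, instead of computing \(\sigma(\hat{x})\) first and then taking its logarithm, which underflows to \(\log 0 = -\infty\) for large \(|\hat{x}|\).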
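Both notes can be checked numerically with a short sketch. It uses only standard PyTorch calls (`torch.distributions.Normal` and `F.binary_cross_entropy_with_logits`), not pyautoencoder itself: the full unit-variance Gaussian log-density differs from the documented expression by the same constant for every element, and a naive `log(sigmoid(z))` collapses to `-inf` at a logit where the BCE-based form stays finite.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal

# Gaussian note: the omitted term is the constant -0.5 * log(2*pi) per element.
x = torch.tensor([0.3, -1.2, 2.0])
x_hat = torch.tensor([0.1, -1.0, 1.5])
ll_no_const = -0.5 * (x - x_hat) ** 2
ll_full = Normal(loc=x_hat, scale=1.0).log_prob(x)
gap = ll_full - ll_no_const  # identical for every element

# Bernoulli note: naive log(sigmoid(z)) underflows for very negative logits.
z = torch.tensor([-200.0])
target = torch.ones(1)
naive = torch.log(torch.sigmoid(z))  # sigmoid(-200) rounds to 0 in float32
stable = -F.binary_cross_entropy_with_logits(z, target, reduction="none")
print(gap, naive, stable)
```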

pyautoencoder.loss.kl_divergence_diag_gaussian(mu_q: Tensor, log_var_q: Tensor, mu_p: Tensor | None = None, log_var_p: Tensor | None = None, reduce_sum: bool = True) → Tensor

Compute the KL divergence \(\mathrm{KL}(q \,\|\, p)\) between two diagonal Gaussian distributions.

The first distribution is \(q = \mathcal{N}(\mu_q, \operatorname{diag}(\exp(\log \sigma_q^2)))\).

The second distribution is \(p = \mathcal{N}(\mu_p, \operatorname{diag}(\exp(\log \sigma_p^2)))\). When \(\mu_p\) and \(\log \sigma_p^2\) are None, \(p = \mathcal{N}(0, I)\).

The closed-form KL divergence is:

\[\mathrm{KL}(q \,\|\, p) = \frac{1}{2} \sum_{d} \left( (\log \sigma_{p,d}^2 - \log \sigma_{q,d}^2) + \frac{\exp(\log \sigma_{q,d}^2) + (\mu_{q,d} - \mu_{p,d})^2}{\exp(\log \sigma_{p,d}^2)} - 1 \right)\]

Parameters:
  • mu_q (torch.Tensor) – Mean of the first distribution, shape [B, D_z].

  • log_var_q (torch.Tensor) – Log-variance of the first distribution, shape [B, D_z].

  • mu_p (torch.Tensor or None, optional) – Mean of the second distribution, shape [B, D_z]. Defaults to None, which is treated as 0 (standard normal mean).

  • log_var_p (torch.Tensor or None, optional) – Log-variance of the second distribution, shape [B, D_z]. Defaults to None, which is treated as 0 (standard normal log-variance).

  • reduce_sum (bool, optional) – If True, sum over the latent dimension D_z and return one KL value per sample; if False, return the per-dimension terms. Defaults to True.

Returns:

KL divergences of shape [B] when reduce_sum=True, or [B, D_z] when reduce_sum=False.

Return type:

torch.Tensor
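The closed-form expression above can be sketched in PyTorch as follows, assuming inputs shaped [B, D_z]. The name `kl_diag_gaussian` is hypothetical and this is not the library's source; it mirrors the documented defaults (a None argument is treated as 0, giving a standard-normal prior) and the `reduce_sum` flag, and cross-checks the result against `torch.distributions.kl_divergence` with `scale = exp(0.5 * log_var)`.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_diag_gaussian(mu_q, log_var_q, mu_p=None, log_var_p=None, reduce_sum=True):
    """Hypothetical sketch of the closed-form KL(q || p) for diagonal Gaussians."""
    if mu_p is None:
        mu_p = torch.zeros_like(mu_q)            # standard-normal mean
    if log_var_p is None:
        log_var_p = torch.zeros_like(log_var_q)  # log(1) = 0, i.e. unit variance
    kl = 0.5 * (
        (log_var_p - log_var_q)
        + (log_var_q.exp() + (mu_q - mu_p) ** 2) / log_var_p.exp()
        - 1.0
    )                                            # per-dimension terms, [B, D_z]
    return kl.sum(dim=-1) if reduce_sum else kl  # [B] or [B, D_z]


torch.manual_seed(0)
B, D_z = 5, 16
mu_q, log_var_q = torch.randn(B, D_z), torch.randn(B, D_z)
mu_p, log_var_p = torch.randn(B, D_z), torch.randn(B, D_z)

kl = kl_diag_gaussian(mu_q, log_var_q, mu_p, log_var_p)
# Cross-check with torch.distributions (standard deviation = exp(0.5 * log_var))
reference = kl_divergence(
    Normal(mu_q, (0.5 * log_var_q).exp()),
    Normal(mu_p, (0.5 * log_var_p).exp()),
).sum(dim=-1)
print(kl.shape, torch.allclose(kl, reference, rtol=1e-4, atol=1e-4))
```

Working with log-variances keeps the variances positive by construction, which is why encoders commonly output log_var rather than the variance itself.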

Data Structures & Types

class pyautoencoder.loss.LossResult(objective: Tensor, diagnostics: dict[str, float])

Bases: object

Container for loss computation results with objective and diagnostics.

This dataclass holds the output of model loss computation methods (AE.compute_loss(), VAE.compute_loss(), etc.), separating the optimizable objective from optional diagnostic metrics.

objective

Scalar loss to optimize (e.g., negative log-likelihood or negative ELBO). Maintains gradient information for backpropagation.

Type:

torch.Tensor

diagnostics

Dictionary of scalar metrics for monitoring and logging. Values are plain Python floats rather than tensors, so they carry no gradient information. Examples include log-likelihood, KL divergence, and ELBO.

Type:

dict[str, float]
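A sketch of how a VAE-style loss might populate these fields. Here `LossResult` is redefined locally as a stand-in dataclass with the two documented fields (so the snippet runs without pyautoencoder installed), and `negative_elbo_gaussian` is a hypothetical helper, not the library's `VAE.compute_loss()`; it assumes a unit-variance Gaussian decoder and a standard-normal prior.

```python
from dataclasses import dataclass

import torch


@dataclass
class LossResult:
    """Local stand-in mirroring the documented fields; not imported from pyautoencoder."""
    objective: torch.Tensor
    diagnostics: dict[str, float]


def negative_elbo_gaussian(x, x_hat, mu, log_var) -> LossResult:
    """Hypothetical VAE-style loss: Gaussian log-likelihood minus KL to N(0, I)."""
    log_lik = (-0.5 * (x - x_hat) ** 2).flatten(start_dim=1).sum(dim=1)  # [B]
    kl = 0.5 * (log_var.exp() + mu**2 - 1.0 - log_var).sum(dim=1)       # [B]
    elbo = log_lik - kl                                                   # [B]
    return LossResult(
        objective=-elbo.mean(),          # scalar tensor that keeps the autograd graph
        diagnostics={                    # plain floats for logging only
            "log_likelihood": log_lik.mean().item(),
            "kl": kl.mean().item(),
            "elbo": elbo.mean().item(),
        },
    )


torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)
x_hat = torch.rand(8, 1, 28, 28, requires_grad=True)
mu = torch.randn(8, 2, requires_grad=True)
log_var = torch.randn(8, 2, requires_grad=True)

result = negative_elbo_gaussian(x, x_hat, mu, log_var)
result.objective.backward()  # gradients flow only through the objective
print(result.diagnostics)
```

Keeping the diagnostics as floats means they can be logged without holding on to the computation graph.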

class pyautoencoder.loss.LikelihoodType(value)

Enumeration of supported decoder likelihood models \(p(x \mid z)\).

GAUSSIAN

Gaussian likelihood with fixed unit variance \(\sigma^2 = 1\).

Type:

str

BERNOULLI

Bernoulli likelihood for binary data, with x_hat interpreted as logits.

Type:

str
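Because `likelihood` accepts either a string or an enum member, callers typically normalize it first. The sketch below assumes the members are string-valued with values "gaussian" and "bernoulli" (matching the strings listed for `log_likelihood`); the `(str, Enum)` stand-in class and the `to_likelihood` helper are hypothetical illustrations, not pyautoencoder's code.

```python
from enum import Enum


class LikelihoodType(str, Enum):
    """Stand-in for pyautoencoder.loss.LikelihoodType; member values are assumed."""
    GAUSSIAN = "gaussian"
    BERNOULLI = "bernoulli"


def to_likelihood(value: "str | LikelihoodType") -> LikelihoodType:
    """Hypothetical normalizer: return an enum member for a member or a matching string."""
    if isinstance(value, LikelihoodType):
        return value
    try:
        return LikelihoodType(value)  # lookup by value, e.g. "bernoulli"
    except ValueError:
        options = ", ".join(m.value for m in LikelihoodType)
        raise ValueError(f"unknown likelihood {value!r}; expected one of: {options}") from None


print(to_likelihood("bernoulli").value)
print(LikelihoodType.GAUSSIAN == "gaussian")  # True: str-mixin members compare equal to their value
```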