Base Classes & Utilities

Core interfaces and utilities that form the foundation for all autoencoder implementations.

Internal base classes and utilities for autoencoders.

This module provides the abstract BaseAutoencoder interface and the output container ModelOutput. These form the foundation for all autoencoder implementations in this package.

Warning

This is an internal API. Most users should work with vanilla.AE or variational.VAE instead of using these base classes directly.

class pyautoencoder._base.BaseAutoencoder

Bases: Module, ABC

Base class for autoencoders.

This class defines a common interface for autoencoders that split their logic into encoding, decoding and forward passes.

Training API (gradients enabled)

Subclasses must implement the following abstract methods:

  • _encode(x, *args, **kwargs) -> ModelOutput – Low-level encoder that typically returns a ModelOutput with at least a latent code attribute (for example z).

  • _decode(z, *args, **kwargs) -> ModelOutput – Low-level decoder that typically returns a ModelOutput with at least a reconstruction attribute (for example x_hat).

  • forward(x, *args, **kwargs) -> ModelOutput – Full forward pass used during training. This usually combines encoding and decoding and returns a ModelOutput that may contain both z and x_hat, plus any other quantities needed for loss computation.

  • compute_loss(x, output, *args, **kwargs) -> LossResult – Computes the training objective from the inputs and the output of forward().
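The training contract above can be illustrated with a small sketch. To keep it self-contained it does not import pyautoencoder; instead it re-creates a stripped-down stand-in for the interface. The names MiniBase, TinyAE, EncoderOutput, DecoderOutput and AEOutput are hypothetical, plain dataclasses stand in for ModelOutput subclasses, the linear layers are arbitrary, and compute_loss is omitted for brevity.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from torch import nn


# Hypothetical output containers standing in for ModelOutput subclasses.
@dataclass
class EncoderOutput:
    z: torch.Tensor


@dataclass
class DecoderOutput:
    x_hat: torch.Tensor


@dataclass
class AEOutput:
    z: torch.Tensor
    x_hat: torch.Tensor


class MiniBase(nn.Module, ABC):
    """Hypothetical stand-in for BaseAutoencoder (training API only)."""

    @abstractmethod
    def _encode(self, x: torch.Tensor) -> EncoderOutput: ...

    @abstractmethod
    def _decode(self, z: torch.Tensor) -> DecoderOutput: ...

    @abstractmethod
    def forward(self, x: torch.Tensor) -> AEOutput: ...


class TinyAE(MiniBase):
    def __init__(self, in_dim: int = 8, latent_dim: int = 2) -> None:
        super().__init__()
        self.encoder = nn.Linear(in_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, in_dim)

    def _encode(self, x: torch.Tensor) -> EncoderOutput:
        return EncoderOutput(z=self.encoder(x))

    def _decode(self, z: torch.Tensor) -> DecoderOutput:
        return DecoderOutput(x_hat=self.decoder(z))

    def forward(self, x: torch.Tensor) -> AEOutput:
        # Training pass: chain encoder and decoder; gradients are tracked.
        z = self._encode(x).z
        x_hat = self._decode(z).x_hat
        return AEOutput(z=z, x_hat=x_hat)


model = TinyAE()
out = model(torch.randn(4, 8))
print(out.z.shape, out.x_hat.shape)  # torch.Size([4, 2]) torch.Size([4, 8])
```

Because the base class mixes in ABC, instantiating it (or a subclass that forgets one of the abstract methods) raises a TypeError.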

Inference API (no gradients)

For convenience, the class also exposes high-level inference helpers that are executed under torch.inference_mode():

  • encode() – calls _encode() without tracking gradients.

  • decode() – calls _decode() without tracking gradients.
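The wrapper pattern described above can be sketched as follows. This is a minimal, hypothetical module (InferenceWrappers), not the package's source; for brevity it returns raw tensors instead of ModelOutput objects. It relies on the fact that a torch.inference_mode() instance can be used as a decorator.

```python
import torch
from torch import nn


class InferenceWrappers(nn.Module):
    """Hypothetical module showing the encode()/decode() wrapper pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.enc = nn.Linear(8, 2)
        self.dec = nn.Linear(2, 8)

    # Low-level methods: gradients are tracked when called directly.
    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.enc(x)

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.dec(z)

    # High-level helpers: same computation, run under inference mode.
    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self._encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self._decode(z)


m = InferenceWrappers()
x = torch.randn(3, 8)
print(m._encode(x).requires_grad)  # True
print(m.encode(x).requires_grad)   # False
```

Tensors produced under inference mode are "inference tensors" and cannot later be saved for backward, which is why these helpers suit evaluation but not training.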

Build step

build() performs a no-grad forward pass with a representative input to materialize any lazy (size-inferred) layers. Call it once before training or loading a state dict.

build(input_sample: Tensor) → None

Materialize lazy layers with a representative input.

Performs a no-grad forward pass so that any LazyModuleMixin layers infer their shapes. Call this once before training or loading a state dict.

Parameters:

input_sample (torch.Tensor) – A representative input batch (e.g., a single batch from the training set). Only the shape matters; values are not used.
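What build() does can be sketched with torch.nn.LazyLinear, whose input width is only known after the first forward pass. The class LazyAE, the helper has_unbuilt_lazy_layers and the layer sizes are hypothetical; the sketch assumes, as described above, that building amounts to one forward pass under torch.no_grad().

```python
import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin


def has_unbuilt_lazy_layers(module: nn.Module) -> bool:
    # True if any lazy submodule still holds uninitialized parameters.
    return any(
        isinstance(m, LazyModuleMixin) and m.has_uninitialized_params()
        for m in module.modules()
    )


class LazyAE(nn.Module):
    def __init__(self, latent_dim: int = 4) -> None:
        super().__init__()
        # in_features of both layers are inferred on the first forward pass.
        self.encoder = nn.LazyLinear(latent_dim)
        self.decoder = nn.LazyLinear(16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))

    def build(self, input_sample: torch.Tensor) -> None:
        # No-grad pass: shapes get inferred, no autograd graph is recorded.
        with torch.no_grad():
            self(input_sample)


model = LazyAE()
print(has_unbuilt_lazy_layers(model))  # True
model.build(torch.zeros(2, 16))        # only the shape (here 16 features) matters
print(has_unbuilt_lazy_layers(model))  # False
print(model.encoder.weight.shape)      # torch.Size([4, 16])
```

torch.no_grad() rather than torch.inference_mode() is the safer choice here: parameters materialized under inference mode would become inference tensors and could not be trained afterwards.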

abstractmethod compute_loss(x: Tensor, output: ModelOutput, *args: Any, **kwargs: Any) → LossResult

Compute the loss for the autoencoder.

This abstract method must be implemented by subclasses to compute the appropriate loss objective for the model. Subclasses may support additional hyperparameters and configuration options.

Parameters:
  • x (torch.Tensor) – Ground-truth inputs of shape [B, ...].

  • output (ModelOutput) – Output from the forward pass, containing reconstructions, latent codes, and any other information needed to compute the loss.

  • *args – Additional positional arguments (subclass-specific).

  • **kwargs – Additional keyword arguments (subclass-specific).

Returns:

Result containing the loss objective and optional diagnostics.

Return type:

LossResult
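As an illustration of what a subclass might put in compute_loss, here is a sketch for a plain (non-variational) autoencoder using mean-squared reconstruction error. The fields of the LossResult dataclass below (objective, diagnostics) are hypothetical, since the real class's attributes are not documented in this section; AEOutput is likewise a stand-in for a ModelOutput subclass.

```python
from dataclasses import dataclass, field
from typing import Dict

import torch
import torch.nn.functional as F


@dataclass
class LossResult:
    # Hypothetical fields; the package's LossResult may differ.
    objective: torch.Tensor
    diagnostics: Dict[str, float] = field(default_factory=dict)


@dataclass
class AEOutput:
    z: torch.Tensor
    x_hat: torch.Tensor


def compute_loss(x: torch.Tensor, output: AEOutput) -> LossResult:
    # Mean squared error between the reconstruction and the ground truth.
    mse = F.mse_loss(output.x_hat, x)
    # Keep the differentiable objective; log a detached float as a diagnostic.
    return LossResult(objective=mse, diagnostics={"mse": mse.detach().item()})


x = torch.ones(2, 3)
out = AEOutput(z=torch.zeros(2, 1), x_hat=torch.zeros(2, 3))
res = compute_loss(x, out)
print(res.objective)  # tensor(1.)
```

In training, res.objective would be the tensor passed to backward(), while the diagnostics dictionary is only for logging.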

decode(z: Tensor, *args: Any, **kwargs: Any) → ModelOutput

Decode latent codes without tracking gradients.

This is a thin wrapper around _decode() executed under torch.inference_mode(), making it suitable for evaluation-time decoding.

Parameters:
  • z (torch.Tensor) – Latent representation batch to decode.

  • *args – Additional positional arguments passed to _decode().

  • **kwargs – Additional keyword arguments passed to _decode().

Returns:

The decoder ModelOutput, typically containing at least a reconstruction (for example x_hat).

Return type:

ModelOutput

encode(x: Tensor, *args: Any, **kwargs: Any) → ModelOutput

Encode inputs without tracking gradients.

This is a thin wrapper around _encode() executed under torch.inference_mode(), making it suitable for evaluation-time encoding.

Parameters:
  • x (torch.Tensor) – Input batch to encode.

  • *args – Additional positional arguments passed to _encode().

  • **kwargs – Additional keyword arguments passed to _encode().

Returns:

The encoder ModelOutput, typically containing at least a latent code (for example z).

Return type:

ModelOutput

abstractmethod forward(x: Tensor, *args: Any, **kwargs: Any) → ModelOutput

Full training forward pass of the autoencoder.

This method is responsible for connecting the encoder and decoder and producing all outputs needed for training (for example latents, reconstructions and any auxiliary losses).

Parameters:
  • x (torch.Tensor) – Input batch to encode and decode.

  • *args – Additional positional arguments used by the subclass implementation.

  • **kwargs – Additional keyword arguments used by the subclass implementation.

Returns:

A model output object that typically includes both the latent code (for example z) and the reconstruction (for example x_hat), plus any other training-specific quantities.

Return type:

ModelOutput

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Load a state dict, raising an error if lazy layers have not been built.

Parameters:
  • state_dict (Mapping[str, Any]) – State dictionary to load.

  • strict (bool, optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict(). Defaults to True.

  • assign (bool, optional) – Whether to assign tensors instead of copying them. Defaults to False.

Returns:

Named tuple with missing_keys and unexpected_keys fields, as returned by torch.nn.Module.load_state_dict().

Return type:

torch.nn.modules.module._IncompatibleKeys

Raises:

RuntimeError – If any LazyModuleMixin submodule has not yet been materialized. Call build() first.

class pyautoencoder._base.ModelOutput

Bases: ABC

Base class for autoencoder outputs with a concise, tensor-aware repr.

Subclasses are dataclasses that group together tensors and auxiliary values produced by a model (for example latent codes, reconstructions, losses, etc.). The custom __repr__() implementation prints tensor fields using only their shape and dtype instead of full values, which keeps logs readable even for large tensors.

Notes

Any field whose value is a torch.Tensor is rendered as:

Tensor(shape=(...), dtype=...)

All other fields are rendered using repr().
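The rendering rule above can be reproduced in a few lines. ReprOutput and AEOutput are hypothetical names, and printing dtype as torch.float32 is an assumption about the exact format. Note that the subclass passes repr=False to @dataclass; without it, dataclass would generate its own __repr__ and shadow the inherited one.

```python
from abc import ABC
from dataclasses import dataclass, fields

import torch


class ReprOutput(ABC):
    """Hypothetical stand-in mimicking ModelOutput's tensor-aware repr."""

    def __repr__(self) -> str:
        parts = []
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, torch.Tensor):
                # Summarize tensors by shape and dtype instead of their values.
                text = f"Tensor(shape={tuple(value.shape)}, dtype={value.dtype})"
            else:
                text = repr(value)
            parts.append(f"{f.name}={text}")
        return f"{type(self).__name__}({', '.join(parts)})"


@dataclass(repr=False)  # keep the inherited __repr__
class AEOutput(ReprOutput):
    z: torch.Tensor
    x_hat: torch.Tensor
    note: str = "demo"


out = AEOutput(z=torch.zeros(4, 2), x_hat=torch.zeros(4, 8))
print(out)
# AEOutput(z=Tensor(shape=(4, 2), dtype=torch.float32), x_hat=Tensor(shape=(4, 8), dtype=torch.float32), note='demo')
```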