Base Classes & Utilities
Core interfaces and utilities that form the foundation for all autoencoder implementations.
Internal base classes and utilities for autoencoders.
This module provides the abstract BaseAutoencoder interface and the
output container ModelOutput. These form the foundation for all
autoencoder implementations in this package.
Warning
This is an internal API. Most users should work with vanilla.AE or
variational.VAE instead of using these base classes directly.
- class pyautoencoder._base.BaseAutoencoder
Bases: Module, ABC
Base class for autoencoders.
This class defines a common interface for autoencoders that split their logic into encoding, decoding and forward passes.
Training API (gradients enabled)
Subclasses must implement the following abstract methods:
_encode(x, *args, **kwargs) -> ModelOutput: Low-level encoder that typically returns a ModelOutput with at least a latent code attribute (for example z).
_decode(z, *args, **kwargs) -> ModelOutput: Low-level decoder that typically returns a ModelOutput with at least a reconstruction attribute (for example x_hat).
forward(x, *args, **kwargs) -> ModelOutput: Full forward pass used during training. This usually combines encoding and decoding and returns a ModelOutput that may contain both z and x_hat, plus any other quantities needed for loss computation.
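The three-method contract can be illustrated with a minimal, self-contained sketch. It does not import pyautoencoder: TinyAE and the output dataclasses (EncoderOutput, DecoderOutput, AEOutput) are hypothetical stand-ins for a BaseAutoencoder subclass and its ModelOutput subclasses, and the linear layers are arbitrary choices.

```python
from dataclasses import dataclass

import torch
from torch import nn


# Hypothetical output containers standing in for ModelOutput subclasses.
@dataclass
class EncoderOutput:
    z: torch.Tensor


@dataclass
class DecoderOutput:
    x_hat: torch.Tensor


@dataclass
class AEOutput:
    z: torch.Tensor
    x_hat: torch.Tensor


class TinyAE(nn.Module):
    """Hypothetical stand-in for a BaseAutoencoder subclass."""

    def __init__(self, in_dim: int = 8, latent_dim: int = 2) -> None:
        super().__init__()
        self.encoder = nn.Linear(in_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, in_dim)

    def _encode(self, x: torch.Tensor) -> EncoderOutput:
        # Low-level encoder: exposes a latent code `z`.
        return EncoderOutput(z=self.encoder(x))

    def _decode(self, z: torch.Tensor) -> DecoderOutput:
        # Low-level decoder: exposes a reconstruction `x_hat`.
        return DecoderOutput(x_hat=self.decoder(z))

    def forward(self, x: torch.Tensor) -> AEOutput:
        # Training pass: gradients flow through both encoder and decoder.
        z = self._encode(x).z
        x_hat = self._decode(z).x_hat
        return AEOutput(z=z, x_hat=x_hat)


model = TinyAE()
out = model(torch.randn(4, 8))
```

Because forward() runs with gradients enabled, out.x_hat is attached to the autograd graph and can feed a loss directly.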
Inference API (no gradients)
For convenience, the class also exposes the high-level inference helpers encode() and decode(), which are executed under torch.inference_mode().
Build step
build() performs a no-grad forward pass with a representative input to materialize any lazy (size-inferred) layers. Call it once before training or loading a state dict.
- build(input_sample: Tensor) -> None
Materialize lazy layers with a representative input.
Performs a no-grad forward pass so that any LazyModuleMixin layers infer their shapes. Call this once before training or loading a state dict.
- Parameters:
input_sample (torch.Tensor) – A representative input batch (e.g., a single batch from the training set). Only its shape matters; the values do not affect the inferred layer sizes.
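What the build step does can be sketched with plain PyTorch lazy layers. The free functions build and has_unbuilt_lazy_layers below are hypothetical helpers that mirror the described behavior; they are not the library's method bodies.

```python
import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin


def has_unbuilt_lazy_layers(module: nn.Module) -> bool:
    # A lazy layer is unbuilt while it still holds uninitialized parameters.
    return any(
        isinstance(m, LazyModuleMixin) and m.has_uninitialized_params()
        for m in module.modules()
    )


def build(module: nn.Module, input_sample: torch.Tensor) -> None:
    # One forward pass without gradient tracking lets lazy layers infer shapes.
    with torch.no_grad():
        module(input_sample)


net = nn.Sequential(nn.LazyLinear(16), nn.ReLU(), nn.LazyLinear(4))
before = has_unbuilt_lazy_layers(net)
build(net, torch.randn(2, 10))  # in_features inferred as 10
after = has_unbuilt_lazy_layers(net)
```

After the pass, the first layer's weight has shape (16, 10) and the second's (4, 16); only the trailing input dimension of the sample determined them.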
- abstractmethod compute_loss(x: Tensor, output: ModelOutput, *args: Any, **kwargs: Any) -> LossResult
Compute the loss for the autoencoder.
This abstract method must be implemented by subclasses to compute the appropriate loss objective for the model. Subclasses may support additional hyperparameters and configuration options.
- Parameters:
x (torch.Tensor) – Ground-truth inputs of shape [B, ...].
output (ModelOutput) – Output from the forward pass, containing reconstructions, latent codes, and any other information needed to compute the loss.
*args – Additional positional arguments (subclass-specific).
**kwargs – Additional keyword arguments (subclass-specific).
- Returns:
Result containing the loss objective and optional diagnostics.
- Return type:
LossResult
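A subclass's compute_loss could look like the sketch below, which uses a mean-squared reconstruction error. LossResult's real fields are not documented here, so LossResultSketch (with hypothetical objective and diagnostics fields) and OutputSketch are stand-ins, and the reduction keyword is an example of a subclass-specific option.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import torch
import torch.nn.functional as F


@dataclass
class LossResultSketch:
    # Hypothetical stand-in; field names are assumptions, not LossResult's API.
    objective: torch.Tensor
    diagnostics: Dict[str, Any] = field(default_factory=dict)


@dataclass
class OutputSketch:
    # Hypothetical forward output carrying a reconstruction.
    x_hat: torch.Tensor


def compute_loss(
    x: torch.Tensor, output: OutputSketch, reduction: str = "mean"
) -> LossResultSketch:
    # Reconstruction error between the model output and the ground truth.
    loss = F.mse_loss(output.x_hat, x, reduction=reduction)
    # Diagnostics hold detached Python numbers, safe for logging.
    return LossResultSketch(objective=loss, diagnostics={"mse": loss.detach().item()})


x = torch.zeros(2, 3)
result = compute_loss(x, OutputSketch(x_hat=torch.ones(2, 3)))
```

Keeping the differentiable objective separate from detached diagnostics lets a training loop call result.objective.backward() while logging the rest.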
- decode(z: Tensor, *args: Any, **kwargs: Any) -> ModelOutput
Decode latent codes without tracking gradients.
This is a thin wrapper around _decode() executed under torch.inference_mode(), making it suitable for evaluation-time decoding.
- Parameters:
z (torch.Tensor) – Latent representation batch to decode.
*args – Additional positional arguments passed to _decode().
**kwargs – Additional keyword arguments passed to _decode().
- Returns:
The decoder ModelOutput, typically containing at least a reconstruction (for example x_hat).
- Return type:
ModelOutput
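The "thin wrapper under torch.inference_mode()" pattern used by decode() and encode() can be sketched as follows. WrappedAE, Latent, and Reconstruction are hypothetical stand-ins; the library's actual wrappers may differ in detail.

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class Latent:
    z: torch.Tensor


@dataclass
class Reconstruction:
    x_hat: torch.Tensor


class WrappedAE(nn.Module):
    # Hypothetical stand-in showing the wrapper pattern only.
    def __init__(self) -> None:
        super().__init__()
        self.enc = nn.Linear(6, 2)
        self.dec = nn.Linear(2, 6)

    def _encode(self, x: torch.Tensor) -> Latent:
        return Latent(z=self.enc(x))

    def _decode(self, z: torch.Tensor) -> Reconstruction:
        return Reconstruction(x_hat=self.dec(z))

    def encode(self, x: torch.Tensor, *args, **kwargs) -> Latent:
        # Evaluation-time helper: no autograd graph is recorded.
        with torch.inference_mode():
            return self._encode(x, *args, **kwargs)

    def decode(self, z: torch.Tensor, *args, **kwargs) -> Reconstruction:
        with torch.inference_mode():
            return self._decode(z, *args, **kwargs)


model = WrappedAE()
x = torch.randn(3, 6)
x_hat = model.decode(model.encode(x).z).x_hat
```

Tensors produced this way do not require gradients and are flagged as inference tensors (torch.is_inference returns True).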
- encode(x: Tensor, *args: Any, **kwargs: Any) -> ModelOutput
Encode inputs without tracking gradients.
This is a thin wrapper around _encode() executed under torch.inference_mode(), making it suitable for evaluation-time encoding.
- Parameters:
x (torch.Tensor) – Input batch to encode.
*args – Additional positional arguments passed to _encode().
**kwargs – Additional keyword arguments passed to _encode().
- Returns:
The encoder ModelOutput, typically containing at least a latent code (for example z).
- Return type:
ModelOutput
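Because encode() runs under torch.inference_mode(), the latents it returns are inference tensors, and PyTorch refuses to save them for backward. The sketch below reproduces this with a tensor created in inference mode, used here as a stand-in for an encode() result, and shows the clone() workaround that PyTorch itself suggests.

```python
import torch

# Stand-in for a latent returned by encode(): created inside inference mode.
with torch.inference_mode():
    z = torch.randn(4, 2)

w = torch.ones(2, requires_grad=True)

try:
    # Computing dL/dw would require saving `z`, which is not allowed.
    (w * z).sum().backward()
    raised = False
except RuntimeError:
    raised = True

# A clone made outside inference mode is a normal tensor usable in autograd.
z_normal = z.clone()
(w * z_normal).sum().backward()
```

For training-time use of latents, calling forward() (or _encode() directly) avoids the issue altogether.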
- abstractmethod forward(x: Tensor, *args: Any, **kwargs: Any) -> ModelOutput
Full training forward pass of the autoencoder.
This method is responsible for connecting the encoder and decoder and producing all outputs needed for training (for example latents, reconstructions and any auxiliary losses).
- Parameters:
x (torch.Tensor) – Input batch to encode and decode.
*args – Additional positional arguments used by the subclass implementation.
**kwargs – Additional keyword arguments used by the subclass implementation.
- Returns:
A model output object that typically includes both the latent code (for example z) and the reconstruction (for example x_hat), plus any other training-specific quantities.
- Return type:
ModelOutput
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)
Load a state dict, raising an error if lazy layers have not been built.
- Parameters:
state_dict (Mapping[str, Any]) – State dictionary to load.
strict (bool, optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict(). Defaults to True.
assign (bool, optional) – Whether to assign tensors instead of copying them. Defaults to False.
- Returns:
Named tuple with missing_keys and unexpected_keys fields, as returned by torch.nn.Module.load_state_dict().
- Return type:
torch.nn.modules.module._IncompatibleKeys
- Raises:
RuntimeError – If any LazyModuleMixin submodule has not yet been materialized. Call build() first.
- class pyautoencoder._base.ModelOutput
Bases: ABC
Base class for autoencoder outputs with a concise, tensor-aware repr.
Subclasses are dataclasses that group together tensors and auxiliary values produced by a model (for example latent codes, reconstructions, and losses). The custom __repr__() implementation prints tensor fields using only their shape and dtype instead of their full values, which keeps logs readable even for large tensors.
Notes
Any field whose value is a torch.Tensor is rendered as:
Tensor(shape=(...), dtype=...)
All other fields are rendered using repr().
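A tensor-aware __repr__ following the rendering rules above can be sketched with dataclasses.fields. ShortReprOutput and ExampleOutput are hypothetical stand-ins, not ModelOutput's source, and the exact shape and dtype formatting is an assumption.

```python
from dataclasses import dataclass, fields

import torch


class ShortReprOutput:
    # Hypothetical stand-in for ModelOutput's tensor-aware __repr__.
    def __repr__(self) -> str:
        parts = []
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, torch.Tensor):
                # Tensors: show only shape and dtype, never the values.
                text = f"Tensor(shape={tuple(value.shape)}, dtype={value.dtype})"
            else:
                text = repr(value)
            parts.append(f"{f.name}={text}")
        return f"{type(self).__name__}({', '.join(parts)})"


# repr=False stops @dataclass from generating its own __repr__, which would
# otherwise replace the inherited one.
@dataclass(repr=False)
class ExampleOutput(ShortReprOutput):
    z: torch.Tensor
    note: str


out = ExampleOutput(z=torch.zeros(2, 3), note="ok")
```

Here repr(out) yields ExampleOutput(z=Tensor(shape=(2, 3), dtype=torch.float32), note='ok'). The repr=False detail matters in this sketch because a plain @dataclass decorator always defines __repr__ on the subclass itself.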