31 lines
1 KiB
Python
31 lines
1 KiB
Python
from dataclasses import dataclass
|
|
|
|
from ..utils import BaseOutput
|
|
|
|
|
|
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of the AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    # The annotation is a string (forward reference) because
    # `DiagonalGaussianDistribution` is not imported in this module; the name is
    # undefined here, which is why the linter's undefined-name check is silenced.
    latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
|
|
|
|
|
|
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    # The annotation is a string because `torch` is not imported in this module;
    # the name is undefined here, which is why the undefined-name lint is silenced.
    sample: "torch.Tensor"  # noqa: F821
|