YouTube Courses - Learn Smarter

YouTube-Courses.site transforms YouTube videos and playlists into structured courses, making learning more efficient and organized.

Building a Variational Autoencoder (VAE) from Scratch with PyTorch: An Educational Guide



This chapter provides a comprehensive guide to understanding, implementing, and training a Variational Autoencoder (VAE) using PyTorch. VAEs are powerful neural networks with significant applications, particularly in generative models like Stable Diffusion and DALL-E. This chapter will cover the theoretical foundations of VAEs, their architecture, implementation details, and training process.

1. Introduction to Variational Autoencoders (VAEs)

Variational Autoencoders (VAEs) are a type of autoencoder and a powerful class of neural networks trained using unsupervised learning techniques. They have gained prominence in the field of image generation, especially within latent diffusion models and Generative Adversarial Networks (GANs).

Variational Autoencoder (VAE): A type of generative neural network that learns a probabilistic latent space representation of data. Unlike traditional autoencoders, VAEs learn a distribution over the latent space, enabling them to generate new data samples.

1.1. The Role of VAEs in Image Generation

VAEs play a crucial role in modern image generation models such as Stable Diffusion and DALL-E. These models often operate on a compressed representation of images to reduce computational cost and improve efficiency. VAEs are instrumental in creating this compressed, feature-rich representation, known as the latent space.

1.2. From Autoencoders to VAEs: Addressing Limitations

To understand VAEs, it’s essential to first grasp the concept of autoencoders, as VAEs build upon their foundation.

1.2.1. Autoencoders: Dimensionality Reduction and Feature Extraction

Autoencoders are unsupervised learning neural networks primarily used for dimensionality reduction. In the context of image processing, training models like Stable Diffusion directly on high-resolution images (e.g., 4K images with millions of pixels) is computationally prohibitive. Autoencoders address this challenge by encoding high-dimensional image data into a lower-dimensional latent vector that retains essential features.

Autoencoder: A neural network architecture designed to learn compressed representations of data in an unsupervised manner. It consists of an encoder network that compresses the input data into a lower-dimensional latent space and a decoder network that reconstructs the original data from this latent representation.

  • Dimensionality Reduction: Autoencoders reduce the number of dimensions needed to represent data, making computations more efficient.
  • Feature Extraction: The latent vector learned by an autoencoder encapsulates the most salient features of the input data.

1.2.2. Limitations of Traditional Autoencoders for Generation

While autoencoders excel at compression and reconstruction, they have limitations when it comes to generating new, varied outputs. Traditional autoencoders are designed to memorize patterns and compress images into fixed latent vectors. They lack the inherent ability to generate diverse variations of the data they’ve been trained on. If asked to generate a specific image, they tend to reproduce something very close to the training examples, rather than creating novel variations.

1.2.3. VAEs: Introducing Probability and Variation

VAEs overcome the generative limitations of autoencoders. Unlike autoencoders that memorize patterns and map data points to fixed latent vectors, VAEs learn and predict probability distributions over the latent space.

  • Probability Distribution: Instead of a single point in latent space, each data point is encoded as a distribution, characterized by a mean and variance.
  • Latent Space as Pools: Imagine training a VAE on digits. Instead of each digit (e.g., ‘7’, ‘2’) being represented by a fixed point in latent space, VAEs create “pools” or distributions for each digit. This means that the latent representation for ‘7’ is not a single vector, but a range of vectors defined by a mean and variance, forming a cluster in the latent space.

By learning these probability distributions, VAEs gain the ability to generate variations of images. Because each class of images (like digits in the example) is represented by a distribution, sampling from these distributions allows the VAE to create new instances that are similar to, but not identical to, the training data.

2. Understanding the VAE Architecture and Workflow

The VAE architecture involves two main components: an encoder and a decoder. The training process is driven by a specific loss function that combines reconstruction accuracy and a regularization term.

2.1. VAE Architecture Components

  • Encoder: The encoder network takes an input image and maps it to a probability distribution in the latent space. Crucially, the encoder outputs two vectors: the mean (μ) and the standard deviation (σ) (or log variance, often used for numerical stability). These parameters define a Gaussian distribution for each input image.
  • Latent Space: The latent space is the compressed representation space learned by the VAE. In VAEs, this space is probabilistic. Instead of encoding each input to a single point, VAEs encode to a distribution in this space.
  • Decoder: The decoder network takes a sample from the latent distribution and attempts to reconstruct the original input image.

2.2. VAE Workflow: Encoding, Sampling, and Decoding

  1. Input Image: An image is fed into the encoder.
  2. Encoder Output (Mean and Log Variance): The encoder processes the image and outputs the mean (μ) and log variance (log(σ2)) vectors.
  3. Latent Vector Sampling (Reparameterization Trick): To enable backpropagation during training, a technique called the reparameterization trick is used. A latent vector (z) is sampled from the distribution defined by μ and σ. This is typically done by sampling from a standard normal distribution (ε ~ N(0, 1)) and then transforming it: z = μ + σ * ε.
  4. Decoder Input (Latent Vector): The sampled latent vector (z) is passed to the decoder.
  5. Reconstructed Image: The decoder attempts to reconstruct the original image from the latent vector.

2.3. Training the VAE: Loss Function

Training a VAE involves optimizing a loss function that encourages both accurate reconstruction and a well-structured latent space. The VAE loss function typically consists of two terms:

  • Reconstruction Loss: This term measures how well the decoder reconstructs the original input image from the latent vector. Common reconstruction loss functions include Mean Squared Error (MSE) loss or Binary Cross-Entropy loss, depending on the nature of the input data.

    Reconstruction Loss: A component of the VAE loss function that quantifies the difference between the original input data and the data reconstructed by the decoder. It encourages the VAE to learn to encode information necessary for accurate reconstruction.

  • KL Divergence (Kullback-Leibler Divergence): This term acts as a regularizer. It measures the difference between the learned latent distribution (defined by μ and σ) and a prior distribution, typically a standard normal distribution N(0, 1). KL divergence encourages the learned latent distributions to be close to the prior, promoting a smooth and continuous latent space, which is crucial for generation.

    KL Divergence (Kullback-Leibler Divergence): A measure of how one probability distribution diverges from a second, reference probability distribution. In VAEs, it is used to regularize the learned latent distribution, ensuring it is close to a chosen prior distribution (usually a standard Gaussian).

The total VAE loss is a weighted sum of the reconstruction loss and the KL divergence:

Total Loss = Reconstruction Loss + β * KL Divergence

Where β (beta) is a hyperparameter that controls the weight of the KL divergence term.

3. Implementing a VAE in PyTorch

This section details the implementation of a VAE using PyTorch, focusing on key modules like self-attention, attention blocks, visual blocks, encoder, and decoder.

3.1. Setting up the Environment and Importing Libraries

To begin, ensure you have PyTorch and necessary libraries installed. A Google Colab notebook with GPU runtime is recommended for efficient training.

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import os

3.2. Self-Attention Mechanism

Self-attention is a crucial component, particularly in advanced architectures. It allows the model to weigh the importance of different parts of the input when processing it.

Self-Attention: A mechanism in neural networks that allows a model to attend to different positions within the same input sequence to compute a representation of the sequence. It is particularly effective in capturing long-range dependencies and relationships within data.

3.2.1. Understanding Self-Attention

Self-attention computes attention scores to determine the relationships between different elements within a sequence (e.g., pixels in an image, tokens in a sentence). It’s based on the “Attention is All You Need” paper and involves calculating Query (Q), Key (K), and Value (V) matrices.

3.2.2. Implementing Self-Attention Module

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.projection = nn.Linear(embed_dim, embed_dim * 3) # For Q, K, V
        self.out_projection = nn.Linear(embed_dim, embed_dim)
        self.d_heads = embed_dim // num_heads

    def forward(self, x):
        batch_size, seq_len, dim = x.shape
        interim_shape = (batch_size, seq_len, self.num_heads, self.d_heads)

        q, k, v = self.projection(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(interim_shape).transpose(1, 2), (q, k, v)) # Reshape for multi-head attention

        weights = torch.matmul(q, k.transpose(-1, -2)) * (self.d_heads ** -0.5) # Scaled dot-product attention
        weights = F.softmax(weights, dim=-1)
        attention = torch.matmul(weights, v)

        attention = attention.transpose(1, 2).reshape(batch_size, seq_len, dim) # Reshape back
        return self.out_projection(attention)
  • embed_dim: The embedding dimension or feature dimension of the input.
  • num_heads: The number of attention heads for parallel processing.
  • projection layer: A linear layer to project the input into Q, K, and V matrices.
  • out_projection layer: A linear layer to project the output back to the embedding dimension.
  • d_heads: Dimension per head, calculated as embed_dim // num_heads.
  • Forward Pass:
    • Projects input x to Q, K, V.
    • Reshapes and transposes for multi-head processing.
    • Calculates attention weights using scaled dot-product attention.
    • Applies softmax to get attention scores.
    • Combines weights and values to get attention output.
    • Reshapes and projects back to the original dimension.

3.3. Attention Block

The attention block integrates the self-attention mechanism with residual connections and group normalization to create a more robust and effective layer.

Residual Connection (Skip Connection): A technique in deep neural networks where the input of a layer is added to its output. This helps to mitigate the vanishing gradient problem and enables the training of deeper networks.

Group Normalization: A normalization technique that divides channels into groups and normalizes within each group. It is less sensitive to batch size than batch normalization and often performs better in generative models.

class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels) # Group normalization layer
        self.attention = SelfAttention(channels, num_heads) # Self-attention module
        self.residual_scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        residual = x.clone() # Residual connection
        x = self.groupnorm(x)
        batch_size, channels, height, width = x.shape
        x = x.view(batch_size, channels, height * width).transpose(1, 2) # Reshape for attention (Batch, Sequence, Channels)
        x = self.attention(x)
        x = x.transpose(1, 2).view(batch_size, channels, height, width) # Reshape back
        return x + residual * self.residual_scale # Add residual
  • groupnorm: Group normalization layer for stabilizing training.
  • attention: Self-attention module.
  • residual_scale: Learnable parameter to scale the residual connection.
  • Forward Pass:
    • Applies group normalization.
    • Reshapes input to (Batch, Sequence, Channels) format suitable for self-attention.
    • Applies self-attention.
    • Reshapes back to original image format.
    • Adds residual connection with scaling.

3.4. Visual Block

The visual block is a fundamental building block used in both the encoder and decoder. It typically involves convolutional layers and normalization.

class VisualBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.groupnorm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) # Convolutional layer
        self.groupnorm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.residual_scale = nn.Parameter(torch.ones(1))
        self.identity = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1) # Identity or 1x1 Conv for residual

    def forward(self, x):
        residual = self.identity(x) # Identity mapping for residual if channels match, else 1x1 conv
        x = self.groupnorm1(x)
        x = F.silu(x) # Activation function (SiLU - Sigmoid Linear Unit)
        x = self.conv1(x)
        x = self.groupnorm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x + residual * self.residual_scale # Add residual
  • groupnorm1, groupnorm2: Group normalization layers.
  • conv1, conv2: Convolutional layers with kernel size 3 and padding 1 (shape-preserving).
  • residual_scale: Learnable parameter for residual scaling.
  • identity: Identity layer or 1x1 convolution for residual path, to handle channel changes.
  • Forward Pass:
    • Applies group normalization, SiLU activation, and convolution twice.
    • Adds residual connection, either identity or via 1x1 convolution if channel dimensions differ.

3.5. Encoder Implementation

The encoder network progressively reduces the spatial dimensions of the input image while increasing the number of channels, extracting features and compressing the information into a latent representation.

class Encoder(nn.Module):
    def __init__(self, channels=3, latent_dim=4):
        super().__init__()
        self.conv_in = nn.Conv2d(channels, 32, kernel_size=3, padding=1) # Initial convolution
        self.blocks = nn.ModuleList([
            VisualBlock(32), # Residual blocks
            VisualBlock(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # Strided convolution for downsampling (h/2, w/2)
            VisualBlock(64),
            VisualBlock(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # Downsampling (h/4, w/4)
            VisualBlock(128),
            VisualBlock(128),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), # Downsampling (h/8, w/8)
            VisualBlock(128),
            AttentionBlock(128), # Attention block
            VisualBlock(128),
        ])
        self.norm_out = nn.GroupNorm(32, 128) # Final normalization
        self.conv_out = nn.Conv2d(128, 2 * latent_dim, kernel_size=3, padding=1) # Output convolution for mean and log variance

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        mean, log_var = x.chunk(2, dim=1) # Split output into mean and log variance
        return mean, log_var
  • conv_in: Initial convolutional layer to increase channels.
  • blocks: Sequential list of VisualBlocks, strided convolutions (for downsampling), and AttentionBlock.
  • norm_out: Final group normalization.
  • conv_out: Output convolution to produce mean and log variance (twice the latent dimension).
  • Forward Pass:
    • Applies initial convolution.
    • Iterates through blocks, applying residual blocks, downsampling convolutions, and attention block.
    • Applies normalization and final convolution.
    • Chunks the output into mean and log variance tensors.

3.6. Decoder Implementation

The decoder network performs the reverse operation of the encoder. It takes a latent vector and upsamples it back to the original image dimensions, reconstructing the image.

class Decoder(nn.Module):
    def __init__(self, channels=3, latent_dim=4):
        super().__init__()
        self.conv_in = nn.Conv2d(latent_dim, 128, kernel_size=3, padding=1) # Input convolution from latent space
        self.blocks = nn.ModuleList([
            VisualBlock(128), # Residual blocks
            AttentionBlock(128), # Attention block
            VisualBlock(128),
            nn.Upsample(scale_factor=2), # Upsampling (x2) - h/8 -> h/4
            VisualBlock(128),
            VisualBlock(128),
            nn.Upsample(scale_factor=2), # Upsampling (x2) - h/4 -> h/2
            VisualBlock(128),
            VisualBlock(64),
            nn.Upsample(scale_factor=2), # Upsampling (x2) - h/2 -> h
            VisualBlock(64),
            VisualBlock(32),
            VisualBlock(32),
        ])
        self.norm_out = nn.GroupNorm(32, 32) # Final normalization
        self.conv_out = nn.Conv2d(32, channels, kernel_size=3, padding=1) # Output convolution to original channels

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
  • conv_in: Initial convolution layer to increase channels from latent dimension.
  • blocks: Sequential list of VisualBlocks, AttentionBlock, and nn.Upsample layers (for upsampling).
  • norm_out: Final group normalization.
  • conv_out: Output convolution to produce the reconstructed image with original channels.
  • Forward Pass:
    • Applies initial convolution from latent space.
    • Iterates through blocks, applying residual blocks, attention block, and upsampling layers to progressively increase spatial dimensions.
    • Applies normalization and final convolution to reconstruct the image.

3.7. VAE Model Class

This class encapsulates the encoder and decoder and implements the reparameterization trick and the forward pass for the entire VAE.

class VAE(nn.Module):
    def __init__(self, channels=3, latent_dim=4):
        super().__init__()
        self.encoder = Encoder(channels, latent_dim) # Encoder module
        self.decoder = Decoder(channels, latent_dim) # Decoder module
        self.latent_dim = latent_dim

    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var) # Standard deviation from log variance
        eps = torch.randn_like(std) # Sample from standard normal distribution
        return mean + std * eps # Reparameterization trick

    def forward(self, x):
        mean, log_var = self.encoder(x) # Encode to get mean and log variance
        z = self.reparameterize(mean, log_var) # Sample latent vector
        x_recon = self.decoder(z) # Decode latent vector to reconstructed image
        return x_recon, mean, log_var # Return reconstructed image, mean, and log variance
  • encoder: Encoder module.
  • decoder: Decoder module.
  • latent_dim: Latent dimension.
  • reparameterize(mean, log_var): Implements the reparameterization trick to sample latent vectors.
  • forward(x):
    • Encodes input x to get mean and log variance.
    • Reparameterizes to sample latent vector z.
    • Decodes z to get reconstructed image x_recon.
    • Returns reconstructed image, mean, and log variance.

4. Training the VAE

This section outlines the process of training the VAE, including data loading, loss function definition, optimization, and training loop implementation.

4.1. Data Loading and Preprocessing

Download and preprocess the dataset. For example, a dataset of dog images can be used. Split the dataset into training and testing sets. Use torch.utils.data.DataLoader for efficient data loading.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define data transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
])

# Load dataset (replace with your dataset path)
dataset = datasets.ImageFolder(root='./data/dogs', transform=transform) # Assuming dataset is in './data/dogs'

# Split into training and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

4.2. Loss Function Definition

Define the VAE loss function, combining reconstruction loss (e.g., MSE Loss) and KL divergence.

def vae_loss_function(recon_x, x, mean, log_var, beta=1.0):
    recon_loss = F.mse_loss(recon_x, x, reduction='mean') # Reconstruction loss (MSE)
    kl_divergence = -0.5 * torch.mean(1 + log_var - mean.pow(2) - log_var.exp()) # KL Divergence
    return recon_loss + beta * kl_divergence # Total VAE loss
  • recon_loss: Mean Squared Error between reconstructed image and original image.
  • kl_divergence: KL divergence term to regularize latent space.
  • beta: Weight for KL divergence term.

4.3. Optimizer and Training Loop

Initialize the VAE model, optimizer (e.g., Adam), and implement the training loop.

# Initialize model, optimizer, and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # Adam optimizer

num_epochs = 20
train_losses = []

for epoch in range(num_epochs):
    model.train() # Set model to training mode
    epoch_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device) # Move data to device
        optimizer.zero_grad() # Zero gradients
        recon_batch, mean, log_var = model(data) # Forward pass
        loss = vae_loss_function(recon_batch, data, mean, log_var) # Calculate loss
        loss.backward() # Backward pass
        optimizer.step() # Optimizer step
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_epoch_loss:.4f}')

    # (Optional) Save reconstructed images or model checkpoints here
  • Device: Set to CUDA if available, otherwise CPU.
  • Model and Optimizer: Initialize VAE model and Adam optimizer.
  • Training Loop:
    • Iterate through epochs.
    • Set model to train mode (model.train()).
    • Iterate through data loader.
    • Move data to device.
    • Zero gradients (optimizer.zero_grad()).
    • Forward pass (model(data)).
    • Calculate loss (vae_loss_function(...)).
    • Backward pass (loss.backward()).
    • Optimizer step (optimizer.step()).
    • Accumulate epoch loss.
    • Print average epoch loss.
    • (Optional) Save reconstructed images or model checkpoints.

5. Integrating Custom VAE into Stable Diffusion

This section briefly touches on how a trained custom VAE can be integrated into a Stable Diffusion pipeline.

5.1. Creating a Diffusers-Compatible VAE

To integrate the custom VAE with libraries like Diffusers (used for Stable Diffusion), you might need to create a wrapper or a compatible class that conforms to the expected interface. This may involve defining specific functions like encode, decode, and load_pretrained_weights to align with the Diffusers pipeline requirements.

5.2. Replacing the Default VAE in Stable Diffusion

Once you have a Diffusers-compatible VAE, you can replace the default VAE component in a Stable Diffusion pipeline with your custom-trained VAE. This typically involves loading the Stable Diffusion pipeline and then injecting your custom VAE into the pipeline’s VAE component.

5.3. Inference with Custom VAE

After integration, you can run inference with the modified Stable Diffusion pipeline using your custom VAE. This allows you to generate images using the Stable Diffusion model but with the latent space representation learned by your VAE.

Note: Integrating a custom-trained VAE into a complex pipeline like Stable Diffusion can be intricate and might require adjustments to ensure compatibility and optimal performance. The provided transcript briefly touches on this and suggests further exploration and experimentation.

6. Conclusion

This chapter provided a detailed walkthrough of building a Variational Autoencoder from scratch using PyTorch. It covered the theoretical underpinnings, architectural components, implementation of key modules like self-attention, and the training process. While integrating a custom VAE into advanced models like Stable Diffusion requires further steps, this chapter lays a strong foundation for understanding and implementing VAEs for generative tasks. Further experimentation with training parameters, architectures, and datasets will lead to deeper insights and improved generative capabilities.