Implementing a Convolutional Autoencoder with PyTorch : Aditya Sharma

The post content below is copied from PyImageSearch.





Implementing a Convolutional Autoencoder with PyTorch

In this tutorial, we will walk you through training a convolutional autoencoder utilizing the widely used Fashion-MNIST dataset. We will then explore different testing situations (e.g., visualizing the latent space, uniform sampling of data points from this latent space, and recreating images using these sampled points).

We’re about to dive deep into this tutorial. But first things first — you’ll need to access our dataset. We could have hosted it anywhere, but we chose Roboflow and for good reasons!

Let’s rewind a bit. If we got a nickel every time a dataset disappeared from the web, we’d have enough to buy a Tesla. And oh, the frustration! Datasets disappear faster than a plate of hot cookies at a tech meetup (we’re still salty about the LISA dataset, by the way, 😠).

Roboflow swooped in and saved the day, like Batman but for datasets. It keeps our datasets safe, available, and hassle-free. So, it’s not just us having your back; Roboflow has yours too.

Ready to check out the Fashion-MNIST dataset? All you need is a Roboflow account. It’s free, easy to create, and won’t demand your firstborn in return. Think of it as your all-access pass to our tutorial.

Pause momentarily, tap into your inner data scientist, and register for your no-strings-attached Roboflow account.


Upon completing this tutorial, you will be well-equipped with the knowledge required to implement and train convolutional autoencoders using PyTorch. Moreover, you will gain valuable insights into the capabilities and limitations of convolutional autoencoders.

Let’s embark on this thrilling journey to explore the power of autoencoders with PyTorch!

This lesson is the 2nd of a 4-part series on Autoencoders:

  1. Introduction to Autoencoders
  2. Implementing a Convolutional Autoencoder with PyTorch (this tutorial)
  3. Lesson 3
  4. Lesson 4

To learn to train convolutional autoencoders in PyTorch with post-training embedding analysis on the Fashion-MNIST dataset, just keep reading.



Configuring Your Development Environment

To follow this guide, you need to have torch, torchvision, tqdm, and matplotlib libraries installed on your system.

Luckily, all these libraries are pip-installable:

$ pip install "torch>=2.0.0"
$ pip install "torchvision>=0.15.0"
$ pip install tqdm==4.65.0
$ pip install matplotlib==3.3.2

Need Help Configuring Your Development Environment?

Figure 1: Need help configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in minutes.

All that said, are you:

  • Short on time?
  • Learning on your employer’s administratively locked system?
  • Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
  • Ready to run the code immediately on your Windows, macOS, or Linux system?

Then join PyImageSearch University today!

Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.

And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!


Project Structure

We first need to review our project directory structure.

Start by accessing this tutorial’s “Downloads” section to retrieve the source code and example images.

From there, take a look at the directory structure:

$ tree .
.
├── output
│   ├── embedding_visualize.png
│   ├── image_grid_on_embeddings.png
│   ├── model_weights
│   │   └── best_autoencoder.pt
│   ├── real_test_images_after_train.png
│   ├── real_test_images_before_train.png
│   ├── reconstruct_after_train.png
│   ├── reconstruct_before_train.png
│   └── training_progress
│       ├── epoch10_test_recon.png
│       ├── epoch1_test_recon.png
│       ├── epoch2_test_recon.png
│       ├── epoch3_test_recon.png
│       ├── epoch4_test_recon.png
│       ├── epoch5_test_recon.png
│       ├── epoch6_test_recon.png
│       ├── epoch7_test_recon.png
│       ├── epoch8_test_recon.png
│       └── epoch9_test_recon.png
├── pyimagesearch
│   ├── __init__.py
│   ├── config.py
│   ├── network.py
│   └── utils.py
├── test.py
└── train.py

4 directories, 23 files

In the pyimagesearch directory, we have the following files:

  • config.py: This configuration file is for training the autoencoder.
  • network.py: Hosts the convolutional autoencoder implementation.
  • utils.py: This file contains utilities for post-training autoencoder analysis and a validation method for evaluating the autoencoder during training.

In the project's root directory, we have the following:

  • test.py: This inference script evaluates the trained autoencoder on the test dataset and conducts post-training analysis.
  • train.py: This training script trains the convolutional autoencoder on the Fashion-MNIST dataset.
  • output: This folder hosts the model weights, training reconstruction progress over each epoch, evaluation of the test set, and post-training analysis of the autoencoder.

About the Dataset

In this tutorial, we employ the Fashion-MNIST dataset for training our autoencoder model.


Overview

Fashion-MNIST is a dataset of Zalando’s article images consisting of the following:

  • training set of 60,000 examples
  • test set of 10,000 examples

Each sample is a 28x28 grayscale image associated with a label from 10 classes (Figure 2). It serves as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms, with the benefit of being more representative of real-world computer vision tasks and challenges.

Figure 2: Sample images from the Fashion-MNIST dataset (source: image by the author).

Class Distribution

The Fashion-MNIST dataset is balanced, which means it has an equal number of samples from each class. The 10 classes are T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, and Ankle boot. Each class has 6,000 images in the training set and 1,000 in the test set.


Data Preprocessing

Before training the autoencoder, the images from the dataset are preprocessed. Each image in the dataset is a 28x28 grayscale image with pixel values in the range 0 to 255. As a preprocessing step, these pixel values are scaled to the range [0, 1] by dividing each pixel value by 255. This normalization helps the model converge faster and more stably during training.
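As a quick illustration of the scaling described above, the sketch below normalizes a synthetic 28x28 uint8 image with NumPy. Since the project's config.py sets IMAGE_SIZE = 32, the sketch also zero-pads the image by 2 pixels per side to reach 32x32; that padding step is an assumption about how the 28x28 images are brought to 32x32, not something this section specifies.

```python
import numpy as np


def preprocess(image_uint8, pad=2):
    """Scale a uint8 grayscale image to [0, 1] and zero-pad it.

    The padding (2 pixels per side, 28 -> 32) is an assumption chosen to
    match IMAGE_SIZE = 32; it is not taken from the tutorial's code.
    """
    # convert to float and divide by 255 so pixels move from [0, 255] to [0, 1]
    scaled = image_uint8.astype(np.float32) / 255.0
    # pad every side with zeros (black) to grow 28x28 into 32x32
    return np.pad(scaled, pad_width=pad, mode="constant", constant_values=0.0)


# a synthetic 28x28 image with random pixel values in 0..255
rng = np.random.default_rng(seed=0)
image = rng.integers(0, 256, size=(28, 28)).astype(np.uint8)
image[0, 0], image[0, 1] = 0, 255  # guarantee both extremes are present

processed = preprocess(image)
print(processed.shape, processed.min(), processed.max())
```

In torchvision terms, transforms.ToTensor() performs the same division by 255 when it converts a uint8 PIL image to a float tensor.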


Data Split

The dataset is split into two parts: a training set and a test set. The training set, which contains 60,000 images, is used to train the autoencoder, and the test set, which includes 10,000 images, is used to evaluate the model’s performance. It is essential to separate the data used for training from the data used for testing to get an unbiased measure of the model’s performance.


Configuring the Prerequisites

Before we start our implementation, let’s review our project’s configuration. For that, we will move on to the config.py script located in the pyimagesearch directory.

The config.py script sets up the autoencoder model hyperparameters and creates an output directory for storing training progress metadata, model weights, and post-training analysis plots. It also defines the class labels dictionary mapping from integer to human-readable format.

# import the necessary packages
import os

import torch

# set device to 'cpu' or 'cuda' (GPU) based on availability
# for model training and testing
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# define model hyperparameters
LR = 0.001
PATIENCE = 2
IMAGE_SIZE = 32
CHANNELS = 1
BATCH_SIZE = 64
EMBEDDING_DIM = 2
EPOCHS = 10

# create output directory
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

Lines 2-4 import the os module, which provides operating-system-dependent functionality such as creating directories, and the torch module, a widely used deep learning framework.

On Line 8, we check if CUDA is available on our machine. If CUDA is available, the code will set DEVICE to cuda, and PyTorch will perform its computations on the GPU, which can drastically speed up training for many machine learning models. If CUDA is not available, DEVICE will be set to cpu, and PyTorch will use the CPU for its computations.

Then from Lines 11-17, the following model hyperparameters are defined:

  • LR is the learning rate for the model, which influences how much the model changes in response to the estimated error each time the model weights are updated.
  • PATIENCE can be used for early stopping: training halts when performance on the validation set does not improve for PATIENCE consecutive epochs (here, 2).
  • IMAGE_SIZE defines the height and width of the input images the model is trained on: 32x32 pixels. Since Fashion-MNIST images are 28x28, this implies they are padded or resized to 32x32 during preprocessing.
  • CHANNELS represents the number of color channels in the images. In this case, CHANNELS is set to 1, which suggests that the images will be grayscale. If CHANNELS were 3, that would suggest the images are in full color (red, green, blue).
  • BATCH_SIZE is the number of training examples utilized in one iteration. In this case, the model will look at 64 images at a time before updating its weights.
  • EMBEDDING_DIM is the dimensionality of the latent space (the autoencoder's bottleneck) into which the encoder compresses each image. Setting it to 2 lets us plot the latent space directly.
  • EPOCHS is the number of complete passes through the entire training dataset. The model will be trained over the whole dataset 10 times.
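The PATIENCE rule described above can be sketched in a few lines of plain Python. EarlyStopping below is a hypothetical helper (it is not part of the tutorial's code) that tracks the best validation loss and signals a stop once the loss has failed to improve for PATIENCE consecutive epochs; the loss values are made up for illustration.

```python
class EarlyStopping:
    """Hypothetical early-stopping helper; not part of the tutorial's code."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # improvement: remember the new best loss and reset the counter
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# made-up validation losses, one per epoch
stopper = EarlyStopping(patience=2)
losses = [0.050, 0.041, 0.037, 0.038, 0.039, 0.030]
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}, best loss {stopper.best_loss:.3f}")
        break
```

With PATIENCE = 2, the two non-improving epochs (0.038 and 0.039) trigger the stop, so the later 0.030 is never reached; a larger patience trades extra epochs for tolerance to noisy validation losses.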

On Lines 20 and 21, an output directory is created where the results from the model (e.g., saved model weights or performance plots) are stored. The os.makedirs function creates the directory specified by the first argument. The exist_ok=True argument means that if the directory already exists, the function won’t raise an error and will do nothing.

# create the training_progress directory inside the output directory
training_progress_dir = os.path.join(output_dir, "training_progress")
os.makedirs(training_progress_dir, exist_ok=True)

# create the model_weights directory inside the output directory
# for storing autoencoder weights
model_weights_dir = os.path.join(output_dir, "model_weights")
os.makedirs(model_weights_dir, exist_ok=True)

On Line 24, os.path.join(output_dir, "training_progress") builds the path to a new training_progress directory inside output_dir and stores it in the variable training_progress_dir. On Line 25, os.makedirs creates this directory; again, exist_ok=True means the function won't raise an error if the directory already exists.

This directory stores files related to the model's training progress (e.g., per-epoch reconstruction plots). Lines 29 and 30 follow the same pattern to create a model_weights directory for storing the autoencoder weights.
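To see why exist_ok=True matters, the sketch below builds the same output/training_progress and output/model_weights layout inside a throwaway temporary directory (used only so the example does not touch your working directory) and creates every directory twice; the second pass is a no-op instead of an error.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # mirror the config.py layout under a temporary root directory
    output_dir = os.path.join(root, "output")
    training_progress_dir = os.path.join(output_dir, "training_progress")
    model_weights_dir = os.path.join(output_dir, "model_weights")

    # the second pass would raise FileExistsError without exist_ok=True
    for _ in range(2):
        for path in (output_dir, training_progress_dir, model_weights_dir):
            os.makedirs(path, exist_ok=True)

    created = sorted(os.listdir(output_dir))
    print(created)

    # without exist_ok, re-creating an existing directory raises an error
    raised = False
    try:
        os.makedirs(model_weights_dir)
    except FileExistsError:
        raised = True
    print("FileExistsError raised:", raised)
```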

# define model_weights, reconstruction & real before training images path
MODEL_WEIGHTS_PATH = os.path.join(model_weights_dir, "best_autoencoder.pt")
FILE_RECON_BEFORE_TRAINING = os.path.join(output_dir, "reconstruct_before_train.png")
FILE_REAL_BEFORE_TRAINING = os.path.join(
    output_dir, "real_test_images_before_train.png"
)

# define reconstruction & real after training images path
FILE_RECON_AFTER_TRAINING = os.path.join(output_dir, "reconstruct_after_train.png")
FILE_REAL_AFTER_TRAINING = os.path.join(output_dir, "real_test_images_after_train.png")

# define latent space and image grid embeddings plot path
LATENT_SPACE_PLOT = os.path.join(output_dir, "embedding_visualize.png")
IMAGE_GRID_EMBEDDINGS_PLOT = os.path.join(output_dir, "image_grid_on_embeddings.png")

# define class labels dictionary
CLASS_LABELS = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}

Line 33 defines the path where the best autoencoder weights will be saved as a .pt (PyTorch) file. This is done so you can load the trained model later without retraining it.

Then, on Lines 34-37, we define FILE_RECON_BEFORE_TRAINING and FILE_REAL_BEFORE_TRAINING: the paths where the initial reconstructions from the untrained model and the corresponding real images will be saved before training begins.

On Lines 40 and 41, we define FILE_RECON_AFTER_TRAINING and FILE_REAL_AFTER_TRAINING: these are paths where images will be saved after training the model. The images (or plots) are reconstructions from the trained model and the corresponding real images.

Then on Line 44, we define the path for LATENT_SPACE_PLOT: this is the path where a plot of the embeddings in the latent space will be saved. This would be a 2D plot since the EMBEDDING_DIM is 2.

On Line 45, we define IMAGE_GRID_EMBEDDINGS_PLOT, the path where the plot of the image grid on embeddings will be saved. In this visualization, each point in the 2D latent space corresponds to an image, showing how the model groups similar images together.

Finally, on Lines 48-59, the CLASS_LABELS dictionary maps the integer class labels of the Fashion-MNIST dataset (0 to 9) to their class names. Since the dataset provides labels as integers, this dictionary lets us display human-readable names (e.g., when annotating image grids).
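For example, the integer labels that come out of a Fashion-MNIST DataLoader can be turned into readable names with a simple lookup. The dictionary below is copied from config.py; the label values are made-up examples.

```python
CLASS_LABELS = {
    0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
    5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot",
}

# example integer labels, e.g., taken from one batch of the test set
labels = [9, 0, 7, 3]
names = [CLASS_LABELS[label] for label in labels]
print(names)  # ['Ankle boot', 'T-shirt/top', 'Sneaker', 'Dress']
```

When the labels are PyTorch tensors, each element has to be converted with .item() first, which is what display_images does before the lookup.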


Defining the Utilities

Now that the configuration has been defined, we can define the utilities for validating the autoencoder during training and for generating post-training analysis plots. The utils.py script defines several functions:

  • extract_random_images: randomly selects a set of images and their corresponding labels from a PyTorch DataLoader object.
  • display_images: displays a grid of images along with their labels.
  • display_random_images: extracts a random subset of images from a DataLoader (using extract_random_images) and optionally passes them through an encoder and decoder before displaying them.
  • validate: evaluates the autoencoder on the test set after every epoch.
  • get_test_embeddings: uses the trained encoder to extract embeddings from the test images.
  • plot_latent_space: plots the latent space of the trained encoder using test data.
  • get_random_test_images_embeddings: produces embeddings for a random subset of test images using the encoder.
  • plot_image_grid_on_embeddings: visualizes how the encoder represents the test images in latent space and how the decoder reconstructs these encodings back into the original image space.

Extracting Random Images

# import the necessary packages
import matplotlib
import numpy as np
import torch
import torchvision

from pyimagesearch import config

matplotlib.use("agg")
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from tqdm import tqdm

We start by importing several necessary packages like the following:

  • matplotlib: For creating static, animated, and interactive visualizations in Python. The matplotlib.use("agg") line sets the backend of matplotlib to the ‘agg’ backend, which is a backend used for rendering into a raster format like a PNG file.
  • numpy: It supports arrays and a collection of mathematical functions to operate on these arrays.
  • torch: The core deep learning library we use to build and train the autoencoder model.
  • torchvision: This is a part of PyTorch, consisting of popular datasets, model architectures, and common image transformations for computer vision.
  • config: This module contains various configuration parameters for our project.
  • matplotlib.cm: This is a module for colormap handling utilities. Colormaps are used in Matplotlib to map normalized data values to colors.
  • matplotlib.colors: This module provides classes for converting number or color arguments to RGB or RGBA.
  • matplotlib.pyplot: This module is a state-based interface to matplotlib and provides a MATLAB-like interface.
  • matplotlib.offsetbox: This module provides classes for creating a box around an image and creating an annotation box. It is useful for adding more detailed images or labels to a plot.
  • tqdm: This is a fast, extensible progress bar that we will use to track the autoencoder training progress for each epoch.

def extract_random_images(data_loader, num_images):
    # initialize empty lists to store all images and labels
    all_images = []
    all_labels = []

    # iterate through the data loader to get images and labels
    for images, labels in data_loader:
        # append the current batch of images and labels to the respective lists
        all_images.append(images)
        all_labels.append(labels)
        # stop the iteration if the total number of images exceeds 1000
        if len(all_images) * data_loader.batch_size > 1000:
            break

    # concatenate all the images and labels tensors along the 0th dimension
    all_images = torch.cat(all_images, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # generate random indices for selecting a subset of images and labels
    random_indices = np.random.choice(len(all_images), num_images, replace=False)
    # use the random indices to extract the corresponding images and labels
    random_images = all_images[random_indices]
    random_labels = all_labels[random_indices]

    # return the randomly selected images and labels to the calling function
    return random_images, random_labels

The extract_random_images function randomly selects a certain number of images and their corresponding labels from a PyTorch DataLoader object. Let’s break down the function line-by-line.

On Lines 19 and 20, we initialize two empty lists, all_images and all_labels, to store the images and labels from the DataLoader.

Then, on Lines 23-29, we:

  • Iterate through the DataLoader, which yields batches of images and labels.
  • Append each batch of images and labels to the respective lists.
  • Stop the iteration once the number of collected images exceeds 1,000 (with a batch size of 64, this happens after 16 batches, or 1,024 images), which avoids iterating over the entire dataset.

On Lines 32 and 33, we concatenate all the image and label tensors along the 0th dimension (i.e., the batch dimension).

Next, on Line 36, we generate num_images random indices; replace=False ensures that no image is selected more than once.

Once we have the random_indices, we select the corresponding images and labels on Lines 38 and 39.

Finally, we return the randomly selected images and labels on Line 42.
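The core logic of extract_random_images can be exercised without a real dataset. The sketch below is a NumPy re-implementation (the helper names sample_from_batches and fake_batches are hypothetical) that applies the same two ideas: stop collecting once more than 1,000 images have been gathered, then draw indices with replace=False so no image is picked twice. Each fake image is filled with its own global index so duplicates are easy to detect.

```python
import numpy as np


def sample_from_batches(batches, batch_size, num_images, cap=1000, seed=None):
    """NumPy sketch of extract_random_images (hypothetical helper)."""
    all_images, all_labels = [], []
    for images, labels in batches:
        all_images.append(images)
        all_labels.append(labels)
        # same stopping rule as the tutorial: stop once we exceed `cap` images
        if len(all_images) * batch_size > cap:
            break
    # stack the collected batches along the first (batch) axis
    all_images = np.concatenate(all_images, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    # draw unique random indices into the collected pool
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(all_images), size=num_images, replace=False)
    return all_images[idx], all_labels[idx]


def fake_batches(num_batches, batch_size):
    """Yield (images, labels) batches; each 1x32x32 image holds its global index."""
    for b in range(num_batches):
        ids = b * batch_size + np.arange(batch_size)
        images = ids[:, None, None, None] * np.ones((1, 1, 32, 32))
        yield images, ids % 10


images, labels = sample_from_batches(
    fake_batches(100, 64), batch_size=64, num_images=32, seed=0
)
sampled_ids = images[:, 0, 0, 0].astype(int)
print(images.shape, len(set(sampled_ids.tolist())), sampled_ids.max())
```

Only the first 16 of the 100 fake batches are consumed, so every sampled index falls below 1,024. Note that the pool is only random across the whole dataset if the DataLoader itself shuffles.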


Displaying Images

def display_images(images, labels, num_images_per_row, title, filename=None, show=True):
    # calculate the number of rows needed (ceiling division counts a partial row)
    num_rows = (len(images) + num_images_per_row - 1) // num_images_per_row

    # create a grid of images using torchvision's make_grid function
    grid = torchvision.utils.make_grid(
        images.cpu(), nrow=num_images_per_row, padding=2, normalize=True
    )
    # convert the grid to a NumPy array and transpose it to
    # the correct dimensions
    grid_np = grid.numpy().transpose((1, 2, 0))

    # create a new figure with the appropriate size
    plt.figure(figsize=(num_images_per_row * 2, num_rows * 2))
    # show the grid of images
    plt.imshow(grid_np)
    # remove the axis ticks
    plt.axis("off")
    # set the title of the plot
    plt.title(title, fontsize=16)

The display_images method displays a grid of images with a given title. First, on Line 47, we calculate the number of rows needed to show all the images.

Next, on Lines 50-52, we create a grid of images using torchvision's make_grid function, which takes a 4D mini-batch tensor of shape (B x C x H x W) and tiles it into a single image. nrow is the number of images per row, padding is the number of pixels between images, and normalize=True shifts and scales the pixel values into the range [0, 1] using their minimum and maximum.

On Line 55, we convert the grid to a NumPy array and transpose it. The transpose is required because PyTorch stores images in (C x H x W) format, whereas matplotlib expects (H x W x C).

From Lines 58-64, we

  • create a new figure with the appropriate size
  • show the grid of images
  • remove the axis ticks
  • set the title of the plot

    # add labels for each image in the grid
    for i in range(len(images)):
        # calculate the row and column of the current image in the grid
        row = i // num_images_per_row
        col = i % num_images_per_row
        # get the name of the label for the current image
        label_name = config.CLASS_LABELS[labels[i].item()]
        # add the label name as text to the plot
        plt.text(
            col * (images.shape[3] + 2) + images.shape[3] // 2,
            (row + 1) * (images.shape[2] + 2) - 5,
            label_name,
            fontsize=12,
            ha="center",
            va="center",
            color="white",
            bbox=dict(facecolor="black", alpha=0.5, lw=0),
        )

    # if show is True, display the plot
    if show:
        plt.show()
    else:
        # otherwise, save the plot to a file and close the figure
        plt.savefig(filename, bbox_inches="tight")
        plt.close()

From Lines 67-83, we add labels for each image in the grid:

  • For each image, we calculate its row and column in the grid.
  • We look up the human-readable class name in the config.CLASS_LABELS dictionary, using the image's integer label as the key.
  • We then use plt.text to add the class name to the plot at the calculated coordinates. The text is white, centered, and placed on a semi-transparent black background for readability.

Finally, on Lines 86-91, we close the function with a condition that checks:

  • If show is True, the function displays the plot.
  • Otherwise, we save the plot to the file specified by filename and close the figure.

This allows you to either directly visualize the grid of images or save it to the disk for later use.
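Two details in display_images are easy to get wrong: the channel axis must move from PyTorch's (C, H, W) layout to the (H, W, C) layout that plt.imshow expects, and the row count should round up when the number of images is not a multiple of num_images_per_row. The NumPy sketch below shows both; grid_rows is a hypothetical helper name. For 32 single-channel 32x32 images with nrow=8 and padding=2, make_grid returns a 3-channel grid of 4 * (32 + 2) + 2 = 138 by 8 * (32 + 2) + 2 = 274 pixels, which the fake array mimics.

```python
import numpy as np


def grid_rows(num_images, per_row):
    """Rows needed to hold num_images with per_row images each (ceiling division)."""
    return (num_images + per_row - 1) // per_row


# a fake grid in PyTorch's (C, H, W) layout, shaped like make_grid's output
grid_chw = np.zeros((3, 138, 274), dtype=np.float32)
# move channels last, exactly like grid.numpy().transpose((1, 2, 0))
grid_hwc = grid_chw.transpose((1, 2, 0))

print(grid_hwc.shape)      # (138, 274, 3)
print(grid_rows(32, 8))    # 4
print(grid_rows(30, 8))    # 4, whereas 30 // 8 would give only 3
```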


Displaying Random Images

def display_random_images(
    data_loader,
    encoder=None,
    decoder=None,
    file_recon=None,
    file_real=None,
    title_recon=None,
    title_real=None,
    display_real=True,
    num_images=32,
    num_images_per_row=8,
):
    # extract a random subset of images and labels from the data loader
    random_images, random_labels = extract_random_images(data_loader, num_images)

The display_random_images function extracts a random subset of images from a DataLoader and potentially applies transformations (via an encoder and decoder) before displaying them.

Let’s understand the parameters the function accepts:

  • data_loader: A DataLoader object that yields batches of images and labels.
  • encoder and decoder: Optional PyTorch models or functions that transform the images. If provided, the images will be passed through the encoder and decoder before being displayed.
  • file_recon and file_real: Optional filenames to save the reconstructed and real images.
  • title_recon and title_real: Optional titles for the reconstructed and real image plots.
  • display_real: A boolean determining whether to display the real images.
  • num_images: The number of images to extract from the DataLoader.
  • num_images_per_row: The number of images to display per row in the plot.

On Line 107, we call the extract_random_images function to extract a random subset of images and their corresponding labels from the DataLoader.

    # if an encoder and decoder are provided,
    # use them to generate reconstructions
    if encoder is not None and decoder is not None:
        # set the encoder and decoder to evaluation mode
        encoder.eval()
        decoder.eval()
        # move the random images to the appropriate device
        random_images = random_images.to(config.DEVICE)
        # generate embeddings for the random images using the encoder
        random_embeddings = encoder(random_images)
        # generate reconstructions for the random images using the decoder
        random_reconstructions = decoder(random_embeddings)
        # display the reconstructed images
        display_images(
            random_reconstructions.detach().cpu(),  # detach so it can become NumPy
            random_labels,
            num_images_per_row,
            title_recon,
            file_recon,
            show=False,
        )
        # if specified, also display the original images
        if display_real:
            display_images(
                random_images.cpu(),
                random_labels,
                num_images_per_row,
                title_real,
                file_real,
                show=False,
            )
    # if no encoder and decoder are provided, simply display the original images
    else:
        display_images(
            random_images, random_labels, num_images_per_row, title="Real Images"
        )

From Lines 111-129, if both an encoder and a decoder are provided, we use them to generate reconstructions of the images:

  • We set the encoder and decoder to evaluation mode, necessary if the models contain layers like dropout or batch normalization that behave differently during training and evaluation.
  • Then move the randomly selected images to the device specified in the config (either a CPU or a GPU).
  • Use the encoder to generate embeddings for the images and the decoder to generate reconstructions of the images.
  • Finally, leverage the display_images function to display the reconstructed images and, optionally, the original images. If filenames are provided, it saves the plots to these files.

Otherwise (Lines 141-144), if no encoder and decoder are provided, we simply display the original images using the display_images function.
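Why do the calls to encoder.eval() and decoder.eval() matter? The tutorial's network is not shown in this section, so the toy example below uses a standalone nn.Dropout layer, chosen only because its train/eval difference is easy to observe (whether the actual autoencoder contains dropout is not stated here). In training mode, dropout randomly zeroes activations and rescales the survivors by 1 / (1 - p); in evaluation mode, it passes inputs through unchanged.

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

# training mode: random zeros, surviving values scaled to 1 / (1 - 0.5) = 2
dropout.train()
print(dropout(x))

# evaluation mode: dropout becomes the identity function
dropout.eval()
y_eval = dropout(x)
print(y_eval)
```

Batch normalization shows a similar split: in training mode it normalizes with batch statistics, while in evaluation mode it uses its stored running averages.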


Validating Test Data

def validate(encoder, decoder, test_loader, criterion):
    # set the encoder and decoder to evaluation mode
    encoder.eval()
    decoder.eval()

    # initialize the running loss to 0.0
    running_loss = 0.0

    # disable gradient calculation during validation
    with torch.no_grad():
        # iterate through the test loader
        for batch_idx, (data, _) in tqdm(
            enumerate(test_loader), total=len(test_loader)
        ):
            # move the data to the appropriate device CPU/GPU
            data = data.to(config.DEVICE)
            # encode the data using the encoder
            encoded = encoder(data)
            # decode the encoded data using the decoder
            decoded = decoder(encoded)
            # calculate the loss between the decoded and original data
            loss = criterion(decoded, data)
            # add the loss to the running loss
            running_loss += loss.item()

    # calculate the average loss over all batches
    # and return to the calling function
    return running_loss / len(test_loader)

The validate function is used to evaluate the performance of an encoder-decoder model (often used in autoencoders) on a test dataset.

On Lines 149 and 150, we set the encoder and decoder to evaluation mode. This is necessary because some layers in PyTorch models, such as dropout or batch normalization, behave differently during training and evaluation.

Then, on Line 153, we initialize the running loss to 0.0. This will accumulate the loss for each batch in the test dataset.

On Line 156, we disable gradient calculation because gradients are not needed during evaluation; this saves memory and computation.

Lines 158-170 iterate over all batches in the test dataset. For each batch:

  • Move the data to the device specified in the config (either a CPU or a GPU).
  • Use the encoder to generate embeddings for the data and the decoder to reconstruct the original data from these embeddings.
  • Calculate the loss between the reconstructed and original data using the provided criterion.
  • Add this loss to the running loss.

Finally, on Line 174, after iterating over all test batches, we calculate the average loss by dividing the running loss by the number of batches. This gives the mean loss per batch, which is returned to the calling function.
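To see this averaging in action without the real dataset, the sketch below swaps in a tiny fully connected encoder/decoder and random tensors in place of Fashion-MNIST. Both the toy models and the data are assumptions for illustration only (the tutorial's convolutional network lives in network.py, which is not shown here). The function mirrors the validate logic above, minus the tqdm progress bar.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validate(encoder, decoder, test_loader, criterion, device="cpu"):
    # evaluation mode + no gradient tracking, as in the tutorial's validate()
    encoder.eval()
    decoder.eval()
    running_loss = 0.0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            decoded = decoder(encoder(data))
            running_loss += criterion(decoded, data).item()
    # average over batches (not over individual samples)
    return running_loss / len(test_loader)


torch.manual_seed(0)
# toy stand-ins: flatten 1x32x32 images into a 2-D embedding and back
encoder = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 2))
decoder = nn.Sequential(nn.Linear(2, 32 * 32), nn.Unflatten(1, (1, 32, 32)))

images = torch.rand(100, 1, 32, 32)
labels = torch.randint(0, 10, (100,))
test_loader = DataLoader(TensorDataset(images, labels), batch_size=64)

val_loss = validate(encoder, decoder, test_loader, nn.MSELoss())
print(f"validation loss: {val_loss:.4f}")
```

Here the last batch holds only 36 samples, yet the per-batch average weights it the same as a full batch of 64. For a per-sample mean, one option is to accumulate loss.item() * data.size(0) and divide by len(test_loader.dataset).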


Getting Test Embeddings

def get_test_embeddings(test_loader, encoder):
    # switch the model to evaluation mode
    encoder.eval()

    # initialize empty lists to store the embeddings and labels
    points = []
    label_idcs = []

    # iterate through the test loader
    for i, data in enumerate(test_loader):
        # move the images and labels to the appropriate device
        img, label = [d.to(config.DEVICE) for d in data]
        # encode the test images using the encoder
        proj = encoder(img)
        # convert the embeddings and labels to NumPy arrays
        # and append them to the respective lists
        points.extend(proj.detach().cpu().numpy())
        label_idcs.extend(label.detach().cpu().numpy())
        # free up memory by deleting the images and labels
        del img, label

    # convert the embeddings and labels to NumPy arrays
    points = np.array(points)
    label_idcs = np.array(label_idcs)

    # return the embeddings and labels to the calling function
    return points, label_idcs

The get_test_embeddings function generates and collects the embeddings for all the images in a test dataset using an encoder model.

We start by setting the encoder to evaluation mode on Line 179. As discussed before, this is necessary because some layers in PyTorch models, such as dropout or batch normalization, behave differently during training and evaluation.

On Lines 182 and 183, we initialize two empty lists, points and label_idcs, to store the embeddings and labels of the test images.

From Lines 186-196, we iterate over all the batches in the test dataset. For each batch:

  • Move the images and labels to the device specified in the config (either a CPU or a GPU).
  • Use the encoder to generate embeddings for the images.
  • Convert the embeddings and labels to NumPy arrays and add them, sample by sample, to their respective lists using extend.
  • Delete the images and labels to free up memory.

On Lines 199 and 200, we convert the lists of embeddings and labels to NumPy arrays.

Finally, on Line 203, we return the embeddings and labels to the calling function.
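A possible variant, sketched below, wraps the loop in torch.no_grad() so no autograd graph is built in the first place, and collects whole batches with torch.cat instead of extending Python lists sample by sample. The encoder and dataset are hypothetical toy stand-ins (a single linear layer and random tensors), not the trained model or Fashion-MNIST.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def get_test_embeddings(test_loader, encoder, device="cpu"):
    encoder.eval()
    points, label_idcs = [], []
    # embeddings are only needed as data, so skip building the autograd graph
    with torch.no_grad():
        for img, label in test_loader:
            points.append(encoder(img.to(device)).cpu())
            label_idcs.append(label)
    # one concatenation per list, then convert to NumPy
    return torch.cat(points).numpy(), torch.cat(label_idcs).numpy()


torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 2))  # toy stand-in
dataset = TensorDataset(torch.rand(150, 1, 32, 32), torch.randint(0, 10, (150,)))
points, label_idcs = get_test_embeddings(DataLoader(dataset, batch_size=64), encoder)
print(points.shape, label_idcs.shape)  # (150, 2) (150,)
```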


Visualizing Latent Space

def plot_latent_space(test_loader, encoder, show=False):
    # get the embeddings and labels for the test images
    points, label_idcs = get_test_embeddings(test_loader, encoder)

    # create a new figure and axis for the plot
    fig, ax = plt.subplots(figsize=(10, 10) if not show else (8, 8))

    # create a scatter plot of the embeddings, colored by the labels
    scatter = ax.scatter(
        x=points[:, 0],
        y=points[:, 1],
        s=2.0,
        c=label_idcs,
        cmap="tab10",
        alpha=0.9,
        zorder=2,
    )

    # remove the top and right spines from the plot
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    # add a colorbar to the plot
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.ax.set_ylabel("Labels", rotation=270, labelpad=20)

    # if show is True, display the plot
    if show:
        # add a grid to the plot
        ax.grid(True, color="lightgray", alpha=1.0, zorder=0)
        plt.show()
    # otherwise, save the plot to a file and close the figure
    else:
        plt.savefig(config.LATENT_SPACE_PLOT, bbox_inches="tight")
        plt.close()

The plot_latent_space function visualizes the embeddings produced by the encoder in a 2D scatter plot. Each point in the plot corresponds to an image, and the point's color indicates the image's label. This shows how well the encoder's embeddings separate the different classes, even though the labels are not used to train the autoencoder. If the encoder has learned useful features, images of the same class should have similar embeddings and therefore appear close to each other in the scatter plot.

On Line 208, we first use the get_test_embeddings function to generate and collect the embeddings for all the images in the test dataset.

Then, on Line 211, we create a new figure and axis for the plot.

We create a scatter plot of the embeddings on Lines 214-222. The x and y coordinates of the points are the two dimensions of the embeddings. The color of each point is determined by the label of the corresponding image.

On Lines 225 and 226, we remove the top and right spines from the plot. Then, we add a colorbar to the plot on Lines 229 and 230.

Finally, on Lines 233-240, if show is True, we add a light grid and display the plot. Otherwise, we save the plot to config.LATENT_SPACE_PLOT and close the figure to free memory.


Getting Embeddings for Random Test Images

def get_random_test_images_embeddings(test_loader, encoder, imgs_visualize=5000):
    # get all the images and labels from the test loader
    all_images, all_labels = [], []
    for batch in test_loader:
        images_batch, labels_batch = batch
        all_images.append(images_batch)
        all_labels.append(labels_batch)

    # concatenate all the images and labels into a single tensor
    all_images = torch.cat(all_images, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # randomly select a subset of the images and labels to visualize
    index = np.random.choice(len(all_images), imgs_visualize, replace=False)
    images = all_images[index]
    labels = all_labels[index]

    # encode the selected images directly (no gradients needed)
    with torch.no_grad():
        # test_loader shuffles on every pass, so a second pass would not line up
        # with all_images; embedding the selected images keeps them aligned
        embeddings = encoder(images.to(config.DEVICE)).cpu().numpy()

    # return the randomly selected images, their labels, and their embeddings
    return images, labels, embeddings

The get_random_test_images_embeddings function extracts a random subset of images, their labels, and their embeddings from the test dataset. You can use it to visualize a subset of the images in the latent space, which shows how the encoder maps images to embeddings.

On Lines 245-249, we loop over the batches in the test loader and append the images and labels of each batch to the all_images and all_labels lists, respectively.

Then, on Lines 252 and 253, we concatenate all the images and labels into single torch tensors along the batch dimension.

We then randomly select a subset of the images and labels on Lines 256-258. The imgs_visualize parameter specifies how many images to select, and replace=False ensures that no image is picked twice.

On Lines 261-264, we pass the selected images through the encoder under torch.no_grad() and move the resulting embeddings to the CPU as a NumPy array. We encode the selected images directly instead of calling get_test_embeddings a second time. Because test_loader is created with shuffle=True, a second pass over it returns the images in a different order, so indexing those embeddings with index would pair them with the wrong labels.

Finally, on Line 267, we return the randomly selected images, their labels, and their embeddings.
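Because test_loader is created with shuffle=True, two separate passes over it return the test images in different orders. Positions taken from one pass therefore cannot be used to index results from another pass. The framework-free sketch below simulates this problem. It assumes that a shuffled DataLoader behaves like a fresh random permutation on every pass. The functions shuffled_pass, toy_encoder, and label_matches are hypothetical stand-ins and are not part of the tutorial's code.

```python
import random


def shuffled_pass(dataset, rng):
    """Hypothetical stand-in for one pass over a DataLoader with shuffle=True."""
    order = list(range(len(dataset)))
    rng.shuffle(order)
    return [dataset[i] for i in order]


def toy_encoder(image_id):
    """Hypothetical encoder: maps an 'image' (here just an int id) to a 2D point."""
    return (float(image_id), 2.0 * image_id)


def label_matches(label, embedding):
    # in this toy setup, the true label can be recovered from the embedding
    return int(embedding[0]) % 10 == label


# toy dataset of (image, label) pairs with ten classes
dataset = [(image_id, image_id % 10) for image_id in range(100)]
rng = random.Random(0)

# first pass: collect images and labels (like all_images / all_labels)
first_pass = shuffled_pass(dataset, rng)
images = [image for image, _ in first_pass]
labels = [label for _, label in first_pass]

# second pass: collect embeddings (like a second call over test_loader)
second_pass = shuffled_pass(dataset, rng)
second_pass_embeddings = [toy_encoder(image) for image, _ in second_pass]

# pick 20 random positions without replacement
index = rng.sample(range(len(dataset)), 20)

# mixing passes: labels from pass 1, embeddings from pass 2
naive_ok = sum(label_matches(labels[i], second_pass_embeddings[i]) for i in index)
# single pass: encode the selected pass-1 images directly
fixed_ok = sum(label_matches(labels[i], toy_encoder(images[i])) for i in index)

print(f"correctly paired: mixed passes {naive_ok}/20, direct encoding {fixed_ok}/20")
```

When the passes are mixed, only the pairs whose labels happen to coincide by chance are correct. Encoding the selected images directly pairs every embedding with its own label.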


Visualizing Image Grid on Embeddings

def plot_image_grid_on_embeddings(
    test_loader, encoder, decoder, grid_size=15, figsize=12, show=True
):
    # get a random subset of test images
    # and their corresponding embeddings and labels
    _, labels, embeddings = get_random_test_images_embeddings(test_loader, encoder)

    # create a single figure for the plot
    fig, ax = plt.subplots(figsize=(figsize, figsize))

    # define a custom color map with discrete colors for each unique label
    unique_labels = np.unique(labels)
    num_classes = len(unique_labels)
    cmap = plt.get_cmap("rainbow", num_classes)
    bounds = np.linspace(0, num_classes, num_classes + 1)
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # Plot the scatter plot of the embeddings colored by label
    scatter = ax.scatter(
        embeddings[:, 0],
        embeddings[:, 1],
        cmap=cmap,
        c=labels,
        norm=norm,
        alpha=0.8,
        s=300,
    )

    # Create the colorbar with discrete ticks corresponding to unique labels
    cb = plt.colorbar(scatter, ticks=range(num_classes), spacing="proportional", ax=ax)
    cb.set_ticklabels(unique_labels)

The plot_image_grid_on_embeddings function visualizes two things at once:

  • How the encoder has learned to represent the test images in a lower-dimensional space (the latent space).
  • How the decoder maps points in that space back into the original image space.

The function creates a scatter plot of the latent vectors (embeddings). It then overlays images that the decoder generates from a regular grid of latent points. Together, these show you the quality of the learned embeddings and the effectiveness of the decoder.

Let’s now break down the code line-by-line.

On Line 275, we randomly select a subset of images and their labels from the test dataset. We generate their corresponding embeddings (latent vectors) using the provided encoder. This is done with the help of the get_random_test_images_embeddings function.

Next, on Line 278, a matplotlib figure and axes are initialized with the desired size.

Then, on Lines 281-285, we build a discrete colormap so that each unique label in the subset gets its own color. The rainbow colormap is resampled into as many colors as there are unique labels, and BoundaryNorm maps each integer label to one of these colors.

A scatter plot is created with the embeddings as points on Lines 288-296. The color of each point is determined by its corresponding label. The colormap made earlier is used for this purpose.

Then, on Lines 299 and 300, a colorbar is added to the plot to show the color-label relationship. Each unique label gets a tick on the colorbar.
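matplotlib's BoundaryNorm assigns each value to a color bin based on the boundaries it is given. The pure-Python sketch below reproduces that rule for the boundaries used here. Note that boundary_bin is a hypothetical helper, not a matplotlib function. Every integer label k lands in its own bin [k, k + 1), so each class gets a distinct color. The sketch also shows that the bin centers sit at k + 0.5. The colorbar ticks are placed at range(num_classes), so each tick falls on the lower edge of its color band rather than in the middle.

```python
from bisect import bisect_right


def boundary_bin(value, bounds):
    """Return the index i such that bounds[i] <= value < bounds[i + 1].

    This mirrors how matplotlib's BoundaryNorm assigns a value to a color
    bin when the number of colors equals the number of bins.
    """
    return bisect_right(bounds, value) - 1


num_classes = 10
# same boundaries as np.linspace(0, num_classes, num_classes + 1)
bounds = [float(k) for k in range(num_classes + 1)]

# each integer label k falls into its own bin [k, k + 1)
bins = [boundary_bin(label, bounds) for label in range(num_classes)]
print(bins)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# the middle of bin k is k + 0.5, where a centered colorbar tick would go
centers = [(bounds[k] + bounds[k + 1]) / 2 for k in range(num_classes)]
print(centers[:3])  # [0.5, 1.5, 2.5]
```

To center the ticks, you could pass these center positions as the colorbar's ticks instead of range(num_classes).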

    # Create the grid of images to overlay on the scatter plot
    x = np.linspace(embeddings[:, 0].min(), embeddings[:, 0].max(), grid_size)
    y = np.linspace(embeddings[:, 1].max(), embeddings[:, 1].min(), grid_size)
    xv, yv = np.meshgrid(x, y)
    grid = np.column_stack((xv.ravel(), yv.ravel()))

    # convert the numpy array to a PyTorch tensor
    # and get reconstructions from the decoder
    grid_tensor = torch.tensor(grid, dtype=torch.float32)
    reconstructions = decoder(grid_tensor.to(config.DEVICE))

    # overlay the images on the scatter plot
    for i, (grid_point, img) in enumerate(zip(grid, reconstructions)):
        img = img.squeeze().detach().cpu().numpy()
        imagebox = OffsetImage(img, cmap="Greys", zoom=0.5)
        ab = AnnotationBbox(
            imagebox, grid_point, frameon=False, pad=0.0, box_alignment=(0.5, 0.5)
        )
        ax.add_artist(ab)

  
    plt.show()

On Lines 303-311:

  • A grid of evenly spaced points (grid_size × grid_size) is generated in the latent space.
  • The grid spans the range of the embeddings, so it covers the scatter plot. The y values run from maximum to minimum, so the grid is ordered row by row starting from the top of the plot.
  • The grid is then converted into a tensor and fed into the decoder to generate images.
  • Each grid point represents a position in the latent space, and the decoder generates one image for each position.

Then, on Lines 314-320, each decoded image is overlaid on the scatter plot at its corresponding position in the latent space. This is done by:

  • Creating an AnnotationBbox for each image, which contains the image and its position, and adding it to the axes (ax).
  • The OffsetImage class creates an image box that can be added to the AnnotationBbox.
  • The image is scaled down by setting zoom=0.5. The frameon=False and pad=0.0 parameters ensure that the image box has no frame or padding, and box_alignment=(0.5, 0.5) centers the image at its position.

Finally, Line 323 displays the plot using plt.show().
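The sketch below reproduces the grid construction in pure Python, using made-up latent bounds. The helpers linspace and latent_grid are hypothetical; they mimic np.linspace, np.meshgrid, and ravel as used in the function. The sketch shows the order in which points are generated: grid_size × grid_size points, row by row, starting at the top-left corner (smallest x, largest y).

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop inclusive, like np.linspace."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def latent_grid(x_min, x_max, y_min, y_max, grid_size):
    """Row-major grid of (x, y) points, top row first, like meshgrid + ravel."""
    xs = linspace(x_min, x_max, grid_size)
    ys = linspace(y_max, y_min, grid_size)  # descending: top of the plot first
    # meshgrid(x, y) followed by ravel loops over y (rows) outside, x inside
    return [(x, y) for y in ys for x in xs]


grid = latent_grid(-2.0, 2.0, -1.0, 1.0, grid_size=3)
for point in grid:
    print(point)
```

Each of these points would be fed to the decoder as a 2D embedding. The resulting image would then be drawn at that point's position.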


Defining the Network

# import the necessary packages
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

We start by importing the necessary packages. We use numpy for numerical computing (here, np.prod computes the size of the decoder's fully connected layer), torch for the sigmoid activation function, and torch.nn for building the network's layers. Finally, torch.nn.functional applies the ReLU activations. Alternatively, you could use the nn.ReLU() module.

class Encoder(nn.Module):
    def __init__(self, image_size, channels, embedding_dim):
        super(Encoder, self).__init__()
        # define convolutional layers
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

        # variable to store the shape of the output tensor before flattening
        # the features, it will be used in decoders input while reconstructing
        self.shape_before_flattening = None

        # compute the flattened size after convolutions
        flattened_size = (image_size // 8) * (image_size // 8) * 128
        # define fully connected layer to create embeddings
        self.fc = nn.Linear(flattened_size, embedding_dim)

    def forward(self, x):
        # apply ReLU activations after each convolutional layer
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # store the shape before flattening
        self.shape_before_flattening = x.shape[1:]

        # flatten the tensor
        x = x.view(x.size(0), -1)
        # apply fully connected layer to generate embeddings
        x = self.fc(x)
        return x

As defined above, the Encoder class is a subclass of the PyTorch nn.Module class and defines the encoder part of an autoencoder. The purpose of the encoder is to take an input image and transform it into a lower-dimensional embedding or “code” that represents the essential features of the image.

On Lines 9 and 10, we define the initialization method of the Encoder class and call the parent constructor:

  • The super(Encoder, self).__init__() call runs the nn.Module constructor. This sets up the bookkeeping PyTorch uses to register the layers defined below as submodules and to track their parameters.
  • The initialization method takes three parameters: image_size (the height/width of the input images), channels (the number of color channels in the input images), and embedding_dim (the size of the output embeddings).

Then, from Lines 12-14, three 2D convolutional layers are defined. These layers are used to extract features from the input images. Each convolutional layer halves the height and width of its input due to the stride of 2, while increasing the number of channels.

Line 18 initializes a variable to store the shape of the output tensor before it is flattened. This will be used later to reshape the tensor during the decoding process.

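On Lines 21-23, we compute the size of the flattened feature map and define a fully connected (linear) layer that transforms it into the final embedding. The three stride-2 convolutions halve the height and width three times, so the spatial size is image_size // 8. For 32 × 32 inputs, this gives a 4 × 4 feature map and 4 × 4 × 128 = 2048 features. The formula assumes that image_size is divisible by 8.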

Then, on Line 25, the forward method defines the computations the encoder performs on its input. On Lines 27-29, the input is passed through each convolutional layer, followed by a ReLU activation function.

On Line 32, just before flattening, we store the tensor's shape (excluding the batch dimension) for later use during decoding.

Finally, on Lines 35-37, the tensor is flattened into shape (batch_size, features) and passed through the fully connected layer to generate the final embeddings.

In conclusion, this Encoder class defines a typical convolutional encoder for an autoencoder. The encoder takes in an image, extracts features using convolutional layers, and then generates a lower-dimensional embedding of the image using a fully connected layer.
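You can check the flattened_size formula by hand. The sketch below applies the nn.Conv2d output-size formula from the PyTorch documentation to the encoder's three convolutions (kernel_size=3, stride=2, padding=1). The helper names are hypothetical. For 32 × 32 inputs, image_size // 8 matches the actual 4 × 4 feature map. For unpadded 28 × 28 images it would not: 28 // 8 is 3, but the feature map is 4 × 4. This mismatch is one reason the images are padded to 32 × 32.

```python
def conv2d_out(size, kernel_size=3, stride=2, padding=1, dilation=1):
    """Output height/width of nn.Conv2d (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def encoder_spatial_sizes(image_size, num_layers=3):
    """Trace the spatial size through the encoder's stride-2 convolutions."""
    sizes = []
    size = image_size
    for _ in range(num_layers):
        size = conv2d_out(size)
        sizes.append(size)
    return sizes


for image_size in (32, 28):
    sizes = encoder_spatial_sizes(image_size)
    actual = sizes[-1] * sizes[-1] * 128            # real flattened size
    assumed = (image_size // 8) * (image_size // 8) * 128  # Encoder's formula
    print(image_size, sizes, "flattened:", actual, "formula:", assumed)
```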

class Decoder(nn.Module):
    def __init__(self, embedding_dim, shape_before_flattening, channels):
        super(Decoder, self).__init__()

        # define fully connected layer to unflatten the embeddings
        self.fc = nn.Linear(embedding_dim, np.prod(shape_before_flattening))
        # store the shape before flattening
        self.reshape_dim = shape_before_flattening

        # define transpose convolutional layers
        self.deconv1 = nn.ConvTranspose2d(
            128, 128, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.deconv2 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.deconv3 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        # define final convolutional layer to generate output image
        self.conv1 = nn.Conv2d(32, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # apply fully connected layer to unflatten the embeddings
        x = self.fc(x)
        # reshape the tensor to match shape before flattening
        x = x.view(x.size(0), *self.reshape_dim)

        # apply ReLU activations after each transpose convolutional layer
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        # apply sigmoid activation to the final convolutional layer to generate output image
        x = torch.sigmoid(self.conv1(x))
        return x

The Decoder class, similar to the Encoder class, is a subclass of the PyTorch nn.Module class and defines the decoder part of an autoencoder. The purpose of the decoder is to take an encoded lower-dimensional embedding or “code” and transform it back into the original image.

As before, we define the initialization method of the Decoder class on Line 42. On Line 43, we call the parent constructor with super() so that nn.Module can register the decoder's layers and parameters. The initialization method takes three parameters:

  • embedding_dim: the size of the input embeddings
  • shape_before_flattening: the shape of the tensor before it was flattened in the encoder
  • channels: the number of color channels in the output images

On Line 46, a fully connected (linear) layer is defined. This layer will transform the input embeddings into a flattened tensor that has the same size as the tensor before it was flattened in the encoder.

Line 48 stores shape_before_flattening as self.reshape_dim for later use in reshaping the tensor.

On Lines 51-59, three 2D transposed convolutional layers (sometimes called deconvolutional layers) are defined. Each one doubles the tensor's height and width (4 → 8 → 16 → 32 for 32 × 32 images), while the number of channels goes from 128 down to 32.

Finally, in the __init__ method, a convolutional layer is defined on Line 61. This layer is used to generate the output image from the upsampled tensor.

Moving on to the forward method on Line 63, it defines the computations that the decoder performs on its input.

On Lines 65-67, the input is passed through the fully connected layer and then reshaped to the shape the tensor had before it was flattened in the encoder.

The reshaped tensor is then passed through each transposed convolutional layer, followed by a ReLU activation function on Lines 70-72.

Finally, on Line 74, the tensor is passed through the final convolutional layer, and a sigmoid activation function is applied to generate the output image. The sigmoid function is used here because it squashes its input into the range [0, 1], which is the desired range for the pixel intensities of the output image.
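You can check the upsampling arithmetic with the nn.ConvTranspose2d output-size formula from the PyTorch documentation. With kernel_size=3, stride=2, padding=1, and output_padding=1, each transposed convolution exactly doubles the spatial size. The decoder therefore maps the 4 × 4 feature map back to 32 × 32. The helper names in this sketch are hypothetical.

```python
def conv_transpose2d_out(size, kernel_size=3, stride=2, padding=1,
                         output_padding=1, dilation=1):
    """Output height/width of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def decoder_spatial_sizes(start_size, num_layers=3):
    """Trace the spatial size through the decoder's transposed convolutions."""
    sizes = []
    size = start_size
    for _ in range(num_layers):
        size = conv_transpose2d_out(size)
        sizes.append(size)
    return sizes


# the decoder starts from the encoder's 4 x 4 feature map (for 32 x 32 inputs)
sizes = decoder_spatial_sizes(4)
print(sizes)  # [8, 16, 32]

# the final 3x3 convolution with stride=1 and padding=1 keeps the size unchanged
final = (sizes[-1] + 2 * 1 - (3 - 1) - 1) // 1 + 1
print(final)  # 32
```

The output_padding=1 argument makes each layer produce exactly twice its input size. Without it, the formula gives 2 × size − 1.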


Training the Autoencoder

# USAGE
# python train.py

# import the necessary packages
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm

from pyimagesearch import config, network, utils

We start by importing the necessary libraries. torch is the main library that provides multi-dimensional arrays (tensors) and various methods to manipulate them. torchvision is used to load and transform the data, and tqdm is used for displaying progress bars. The modules from the pyimagesearch package (i.e., config, network, and utils) are also imported.

# define the transformation to be applied to the data
transform = transforms.Compose([transforms.Pad(padding=2), transforms.ToTensor()])

# load the FashionMNIST training data and create a dataloader
trainset = datasets.FashionMNIST("data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=config.BATCH_SIZE, shuffle=True
)

# Load the FashionMNIST test data and create a dataloader
testset = datasets.FashionMNIST("data", train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=config.BATCH_SIZE, shuffle=True
)

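Next, we set up the data transformations and load the FashionMNIST dataset for both training and testing. Each image is padded by 2 pixels on every side and converted into a PyTorch tensor with values in [0, 1]. The padding turns the 28 × 28 images into 32 × 32 images, a size the encoder's three stride-2 convolutions can halve evenly. The data is loaded in batches of size config.BATCH_SIZE, as defined in the configuration file.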


Using Roboflow to Download the Dataset (Optional)

You could also use Roboflow to download the Fashion-MNIST dataset. Roboflow provides a convenient way to download datasets directly via the command line.

Note: Be sure to use your own download link, which contains a private key tied to your Roboflow account. Do not share your private key publicly.

$ mkdir fashion_mnist
$ cd fashion_mnist
$ curl -L -s "YOUR_ROBOFLOW_DOWNLOAD_LINK" > fashion_mnist.zip
$ unzip -q fashion_mnist.zip
$ rm fashion_mnist.zip

Be sure to replace YOUR_ROBOFLOW_DOWNLOAD_LINK with the link you obtain from Roboflow.

To load the dataset downloaded from Roboflow, we use the ImageFolder dataset class from torchvision. We need ImageFolder because the Roboflow download is a directory of image files rather than the format expected by datasets.FashionMNIST.

# define the transformation to be applied to the data
# (ImageFolder loads images as RGB, so convert them back to a single
# channel to match the grayscale images the model is trained on)
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Pad(padding=2),
    transforms.ToTensor()
])

# load the training data
train_dataset = datasets.ImageFolder(root='fashion_mnist/train', transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.BATCH_SIZE, shuffle=True
)

# load the test data
test_dataset = datasets.ImageFolder(root='fashion_mnist/test', transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.BATCH_SIZE, shuffle=True
)
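ImageFolder differs from datasets.FashionMNIST in one important way: it derives each label from the subdirectory name and numbers the classes in sorted (alphabetical) order. The sketch below reproduces that rule with the standard library. Here, find_classes is a hypothetical re-implementation, and the folder names are made up; check the names in your own export. The label indices, and therefore the colors in the latent-space plots, may not match the torchvision FashionMNIST ordering, where T-shirt/top is 0 and Ankle boot is 9. On the real dataset, you can inspect the mapping with train_dataset.class_to_idx.

```python
import os
import tempfile


def find_classes(root):
    """Mimic how torchvision's ImageFolder assigns class indices:
    subdirectory names are sorted, then numbered from 0."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: index for index, name in enumerate(classes)}


# hypothetical class folder names
with tempfile.TemporaryDirectory() as root:
    for name in ["T-shirt", "Trouser", "Bag", "Ankle boot"]:
        os.makedirs(os.path.join(root, name))
    class_to_idx = find_classes(root)

print(class_to_idx)
```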

# generate a random input tensor with the same shape as the input images
# (1: batch size, config.CHANNELS: number of channels,
# config.IMAGE_SIZE: height, config.IMAGE_SIZE: width)
dummy_input = torch.randn(1, config.CHANNELS, config.IMAGE_SIZE, config.IMAGE_SIZE)

# create an encoder instance with the specified channels,
# image size, and embedding dimensions
# then move it to the device (CPU or GPU) specified in the config
encoder = network.Encoder(
    channels=config.CHANNELS,
    image_size=config.IMAGE_SIZE,
    embedding_dim=config.EMBEDDING_DIM,
).to(config.DEVICE)

# pass the dummy input through the encoder to populate
# encoder.shape_before_flattening (the output itself is not used)
enc_out = encoder(dummy_input.to(config.DEVICE))

# get the shape of the tensor before it was flattened in the encoder
shape_before_flattening = encoder.shape_before_flattening

# create a decoder instance with the specified embedding dimensions,
# shape before flattening, and channels
# then move it to the device (CPU or GPU) specified in the config
decoder = network.Decoder(
    config.EMBEDDING_DIM, shape_before_flattening, config.CHANNELS
).to(config.DEVICE)

# instantiate loss, optimizer, and scheduler
criterion = nn.BCELoss()
optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=config.LR
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=config.PATIENCE, verbose=True
)

On Lines 33-51, we create instances of the Encoder and Decoder classes defined in the network module. Model parameters such as the number of channels, the image size, and the embedding dimension are taken from the configuration file. Passing a dummy input through the encoder populates encoder.shape_before_flattening, which the Decoder needs to reshape its input. Both models are moved to the device specified in the configuration file (config.DEVICE).

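Then, on Lines 54-60, we set up the loss function, optimizer, and learning rate scheduler. The loss function is Binary Cross-Entropy (nn.BCELoss). BCE is typically used for binary classification, but it also fits this autoencoder. The targets (pixel intensities scaled to [0, 1] by ToTensor) and the decoder outputs (squashed by the sigmoid) both lie in [0, 1]. A single Adam optimizer updates the parameters of both the encoder and the decoder, and the learning rate scheduler is ReduceLROnPlateau.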
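The pure-Python sketch below shows what nn.BCELoss computes when the targets are pixel intensities: the mean of −[t·log(p) + (1 − t)·log(1 − p)] over all pixels. The bce function is a hypothetical re-implementation for illustration. Like nn.BCELoss, it clamps each log term at −100. For intermediate gray targets, the loss does not reach 0 even when the reconstruction is perfect. For example, a target of 0.5 reconstructed as exactly 0.5 still costs log(2). As a result, Fashion-MNIST reconstruction losses level off above zero.

```python
import math


def bce(predictions, targets):
    """Mean binary cross-entropy over all elements, in the style of nn.BCELoss.

    As in nn.BCELoss, each log term is clamped at -100 to avoid log(0).
    """
    total = 0.0
    for p, t in zip(predictions, targets):
        log_p = max(math.log(p), -100.0) if p > 0 else -100.0
        log_1mp = max(math.log(1 - p), -100.0) if p < 1 else -100.0
        total += -(t * log_p + (1 - t) * log_1mp)
    return total / len(predictions)


# black/white pixels reconstructed perfectly: the loss is 0
print(bce([1.0, 0.0], [1.0, 0.0]))
# a gray pixel (0.5) reconstructed perfectly still costs log(2)
print(bce([0.5], [0.5]), math.log(2))
```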

The ReduceLROnPlateau scheduler lowers the learning rate when a monitored metric stops improving. Here, that metric is the validation loss passed to scheduler.step(). With mode='min', lower values count as improvements. When the metric has not improved for more than patience epochs (config.PATIENCE), the learning rate is multiplied by factor (0.1 in this case).
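Here is a simplified simulation of the scheduler's behavior, written as a minimal sketch under stated assumptions. The function simulate_reduce_on_plateau is hypothetical and ignores the real scheduler's threshold, cooldown, and min_lr options. It treats any strictly lower loss as an improvement. It keeps a count of epochs without improvement and multiplies the learning rate by factor once that count exceeds patience, then resets the count.

```python
def simulate_reduce_on_plateau(losses, lr=1e-3, factor=0.1, patience=2):
    """Simplified ReduceLROnPlateau (mode='min'): returns the LR after each step.

    Ignores the threshold, cooldown, and min_lr options of the real scheduler.
    """
    best = float("inf")
    num_bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:
            best = loss
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
        # reduce only after MORE than `patience` epochs without improvement
        if num_bad_epochs > patience:
            lr *= factor
            num_bad_epochs = 0
        history.append(lr)
    return history


val_losses = [0.40, 0.35, 0.35, 0.36, 0.35, 0.30]
print(simulate_reduce_on_plateau(val_losses, patience=2))
```

With patience=2, the learning rate drops only at the fifth epoch, which is the third consecutive epoch without a new best loss.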

# call the 'display_random_images' function from the 'utils' module to display
# and save random reconstructed images from the test data
# before the autoencoder training
utils.display_random_images(
    test_loader,
    encoder,
    decoder,
    title_recon="Reconstructed Before Training",
    title_real="Real Test Images",
    file_recon=config.FILE_RECON_BEFORE_TRAINING,
    file_real=config.FILE_REAL_BEFORE_TRAINING,
)

Before training starts, the display_random_images utility function displays and saves a set of random test images together with their reconstructions by the still-untrained autoencoder, as shown in Figure 3 and Figure 4. These serve as a baseline for judging the reconstructions after training.

Figure 3: Reconstruction by the autoencoder on the test images before training (source: image by the author).
Figure 4: Real test images fed to the autoencoder before training (source: image by the author).

# initialize the best validation loss as infinity
best_val_loss = float("inf")

# start training by looping over the number of epochs
for epoch in range(config.EPOCHS):
    print(f"Epoch: {epoch + 1}/{config.EPOCHS}")
    # set the encoder and decoder models to training mode
    encoder.train()
    decoder.train()

    # initialize running loss as 0
    running_loss = 0.0

    # loop over the batches of the training dataset
    for batch_idx, (data, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
        # move the data to the device (GPU or CPU)
        data = data.to(config.DEVICE)
        # reset the gradients of the optimizer
        optimizer.zero_grad()

        # forward pass: encode the data and decode the encoded representation
        encoded = encoder(data)
        decoded = decoder(encoded)

        # compute the reconstruction loss between the decoded output and
        # the original data
        loss = criterion(decoded, data)

        # backward pass: compute the gradients
        loss.backward()
        # update the model weights
        optimizer.step()

        # accumulate the loss for the current batch
        running_loss += loss.item()

On Line 76, best_val_loss is initialized to infinity. This variable tracks the smallest validation loss seen so far, so that we can save the weights of the best model across all epochs.

The training process starts on Line 79, looping over the number of epochs specified in config.EPOCHS. On Lines 82 and 83, encoder.train() and decoder.train() put the encoder and decoder in training mode. This matters for layers such as Dropout and Batch Normalization, which behave differently during training and evaluation. Our network has neither, but switching modes explicitly is good practice.

At the start of each epoch, running_loss is reset to 0.0 on Line 86. This variable accumulates the loss over all batches within the current epoch.

Line 89 starts the loop over the batches of the training dataset. The data variable holds the input images for the current batch. The labels are ignored because an autoencoder reconstructs its own input. On Line 91, data is moved to the device specified in config.DEVICE (a GPU or CPU). On Line 93, optimizer.zero_grad() resets the gradients to zero before backpropagation, because PyTorch accumulates gradients across backward() calls.

Lines 96 and 97 perform the forward pass. The input data is passed through the encoder to produce an encoded representation, and the decoder then turns that representation into the reconstructed output.

Then, on Line 101, we compute the reconstruction loss between the decoded output and the original data with the Binary Cross-Entropy (BCE) loss function stored in criterion.

Line 104 is the backward pass. The backward() function computes the gradient of the loss with respect to the model parameters, and optimizer.step() on Line 106 updates the model parameters based on the computed gradients.

The loss of the current batch (converted to a Python float using item()) is added to running_loss to accumulate the loss over the entire epoch on Line 109.

This process is repeated for each batch in the dataset and each epoch. This will train the autoencoder model by iteratively improving its ability to reconstruct the input data.

    # compute the average training loss for the epoch
    train_loss = running_loss / len(train_loader)

    # compute the validation loss
    val_loss = utils.validate(encoder, decoder, test_loader, criterion)

    # print training and validation loss for current epoch
    print(
        f"Epoch {epoch + 1} | Train Loss: {train_loss:.4f} "
        f"| Val Loss: {val_loss:.4f}"
    )

    # save best model weights based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(
            {"encoder": encoder.state_dict(), "decoder": decoder.state_dict()},
            config.MODEL_WEIGHTS_PATH,
        )

    # adjust learning rate based on the validation loss
    scheduler.step(val_loss)

    # save validation output reconstruction for the current epoch
    utils.display_random_images(
        data_loader=test_loader,
        encoder=encoder,
        decoder=decoder,
        file_recon=os.path.join(
            config.training_progress_dir, f"epoch{epoch + 1}_test_recon.png"
        ),
        display_real=False,
    )

print("Training finished!")

Line 112 computes the average training loss for the current epoch by dividing the total accumulated loss (running_loss) by the number of batches in the training dataset (len(train_loader)).

The utils.validate() function is used to compute the validation loss on Line 115. The encoder and decoder models, the validation data loader (test_loader), and the loss function (criterion) are passed as arguments.

Then, on Lines 124-129, we compare the validation loss for the current epoch with the best validation loss seen so far (best_val_loss). If the current loss is lower, we update best_val_loss and save the current model's weights. The encoder and decoder state dictionaries (which contain the model parameters) are saved together in the file specified by config.MODEL_WEIGHTS_PATH.

Line 132 adjusts the learning rate based on the validation loss. The ReduceLROnPlateau scheduler multiplies the learning rate by factor (0.1 here) whenever the validation loss has not improved for more than patience consecutive epochs.

Finally, at the end of each epoch, on Lines 135-143, a set of random test images is passed through the encoder-decoder pipeline, and the reconstructions are saved to config.training_progress_dir. This provides a visual record of how the model's reconstructions improve during training.

The loop continues for the specified number of epochs.


Post-Training Analysis of Autoencoder

# USAGE
# python test.py

# import the necessary packages
import logging
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm

from pyimagesearch import config, utils
from pyimagesearch.network import Decoder, Encoder

# set up logging configuration
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - " "%(levelname)s - %(message)s"
)

We start by importing the necessary packages:

  • logging: records events or issues that occur when running the code.
  • os: provides system-dependent functionality, such as reading from and writing to the file system.
  • torch, torch.nn, torch.optim: parts of the PyTorch library. torch is the main PyTorch library, torch.nn provides classes for building neural networks, and torch.optim provides optimization algorithms (e.g., SGD and Adam).
  • torchvision: provides access to popular datasets, model architectures, and image transformations for computer vision.
  • tqdm: a third-party library for creating console progress bars.
  • pyimagesearch: our custom package containing the project's utility functions, network definition, and configuration variables.
  • Decoder, Encoder: imported from the network module in pyimagesearch. These classes define the decoder and encoder parts of the autoencoder.

On Lines 18-20, the logging.basicConfig() function configures the logging system. The level argument is set to logging.INFO, which means the logger will handle all messages with a severity level of INFO and above. The format argument specifies the format of the log messages. In this case, each message will include the time the log was created, the severity level, and the actual log message.
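The short sketch below shows what this configuration produces. It logs messages at three levels. For the demo, the output goes to an in-memory io.StringIO buffer (an assumption; the script itself logs to the console), and force=True (Python 3.8+) replaces any handlers configured earlier. The two adjacent string literals in format are concatenated by Python into a single format string.

```python
import io
import logging

# capture log output in memory instead of printing it to the console
buffer = io.StringIO()

# same level and format as the test script
logging.basicConfig(
    stream=buffer,
    level=logging.INFO,
    format="%(asctime)s - " "%(levelname)s - %(message)s",
    force=True,  # Python 3.8+: replace any previously configured handlers
)

logging.debug("not shown: this level is below the INFO threshold")
logging.info("Creating and Saving Reconstructed Images with Trained Autoencoder")
logging.warning("shown: WARNING is above INFO")

print(buffer.getvalue())
```

Only the INFO and WARNING messages appear, each prefixed with a timestamp and the level name.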

# generate a random input tensor with the same shape as the input images
# (1: batch size, config.CHANNELS: number of channels,
# config.IMAGE_SIZE: height, config.IMAGE_SIZE: width)
dummy_input = torch.randn(1, config.CHANNELS, config.IMAGE_SIZE, config.IMAGE_SIZE)

# create an encoder instance with the specified channels,
# image size, and embedding dimensions
# then move it to the device (CPU or GPU) specified in the config
encoder = Encoder(
    channels=config.CHANNELS,
    image_size=config.IMAGE_SIZE,
    embedding_dim=config.EMBEDDING_DIM,
).to(config.DEVICE)

# pass the dummy input through the encoder and
# get the output (encoded representation)
enc_out = encoder(dummy_input.to(config.DEVICE))

# get the shape of the tensor before it was flattened in the encoder
shape_before_flattening = encoder.shape_before_flattening

# create a decoder instance with the specified embedding dimensions,
# shape before flattening, and channels
# then move it to the device (CPU or GPU) specified in the config
decoder = Decoder(config.EMBEDDING_DIM, shape_before_flattening, config.CHANNELS).to(
    config.DEVICE
)

# load the saved encoder/decoder state dicts onto config.DEVICE (so GPU-trained weights also load on a CPU)
checkpoint = torch.load(config.MODEL_WEIGHTS_PATH, map_location=config.DEVICE)
encoder.load_state_dict(checkpoint["encoder"])
decoder.load_state_dict(checkpoint["decoder"])

# set the models to evaluation mode
encoder.eval()
decoder.eval()

Line 25 creates a random tensor with the same shape as a single input image (batch size 1). Passing this tensor through the encoder populates encoder.shape_before_flattening, which the Decoder constructor needs to reshape embeddings back into feature maps.

From Lines 30-34, an instance of the Encoder class is created. The parameters for the class (number of channels, image size, and embedding dimensions) are provided in the configuration file. After creating the instance, it’s moved to the device specified in the config file (either a CPU or a GPU) using the .to() method.

Next, on Line 38, the dummy_input tensor is passed through the encoder to get an encoded representation. It’s moved to the specified device before it’s passed to the encoder.

Line 41 retrieves the shape of the tensor before it is flattened in the encoder. The decoder needs this shape to correctly reshape its input during decoding.

From Lines 46-48, an instance of the Decoder class is created. The parameters are the embedding dimensions, the shape before flattening, and the number of channels. Like the encoder, the decoder is also moved to the specified device.

Lines 51-53 load the saved state dictionaries (which contain the trained weights) for the encoder and decoder from the path specified in the config file.

Lastly, on Lines 56 and 57, the models are set to evaluation mode with the .eval() method. This matters for layers such as dropout and batch normalization, which behave differently during training and evaluation. Our network has neither, but calling .eval() before inference is good practice.

# define the transformation to be applied to the data
transform = transforms.Compose([transforms.Pad(padding=2), transforms.ToTensor()])

# load the test data
testset = datasets.FashionMNIST("data", train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=config.BATCH_SIZE, shuffle=True
)

Line 60 defines a series of transformations that will be applied to the images in the dataset. The transforms.Compose function combines all transforms provided in the list. In this case, it applies a padding of 2 pixels around the image and then converts it to a PyTorch tensor.

Then on Lines 63, we load the FashionMNIST test dataset from PyTorch’s dataset library. train=False specifies that you want to load the test set. If the dataset doesn’t exist in the local directory (“data” in this case), it’ll be downloaded automatically due to download=True. When loaded, the transform=transform applies the defined transformations to the data.

Lines 64-66 wrap the dataset in a DataLoader, which handles batching and shuffling and makes it easy to iterate over the dataset. The batch_size comes from the configuration file, and shuffle=True reshuffles the data every time the loader is iterated (i.e., every epoch), so each pass yields different batches.
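A small sketch of the batching and shuffling behavior, using ten fake samples in a `TensorDataset` as a stand-in for the Fashion-MNIST test set (the batch size of 4 is arbitrary, not the tutorial's config value).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# ten fake samples with labels 0..9 stand in for the Fashion-MNIST test set
images = torch.rand(10, 1, 32, 32)
labels = torch.arange(10)
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=True)

for epoch in range(2):
    # shuffle=True draws a fresh permutation each time the loader is iterated
    seen = [batch_labels.tolist() for _, batch_labels in loader]
    print(epoch, seen)

batch_sizes = [len(batch_labels) for _, batch_labels in loader]
print(batch_sizes)  # [4, 4, 2]: the last batch holds the remainder
```

Because the test loader is shuffled, grabbing its first batch is a cheap way to get a random selection of test images, which is what the reconstruction step below relies on.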

logging.info("Creating and Saving Reconstructed Images with Trained Autoencoder")
# call the 'display_random_images' function from the 'utils' module to display
# and save random reconstructed images from the test data
# after the autoencoder training
utils.display_random_images(
    test_loader,
    encoder,
    decoder,
    title_recon="Reconstructed After Training",
    title_real="Real Test Images After Training",
    file_recon=config.FILE_RECON_AFTER_TRAINING,
    file_real=config.FILE_REAL_AFTER_TRAINING,
)

Here, we generate and save a few randomly selected test images along with the trained autoencoder’s reconstructions of them.

Line 68 logs the start of the image generation process.

The display_random_images function from the utils module is called on Lines 72-80.

  • This function randomly selects a batch of images from the test data loader, passes them through the encoder and decoder to generate the reconstructed images, and then saves both the original and reconstructed images.
  • The titles and file paths for the saved images are passed in as parameters.
  • The title_recon and title_real parameters specify the titles for the reconstructed and real images, respectively.
  • The file_recon and file_real parameters specify the file paths where the images will be saved.
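The core of what these bullets describe, minus the plotting and file saving, can be sketched as follows. `reconstruct_random_batch` is a hypothetical helper, not the tutorial’s `display_random_images`; it assumes the encoder and decoder are already in evaluation mode on `device`, and the flatten/unflatten pair is a toy stand-in for a trained autoencoder.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def reconstruct_random_batch(loader, encoder, decoder, device="cpu"):
    """Hypothetical helper: return one shuffled batch and its reconstructions."""
    images, _ = next(iter(loader))  # shuffle=True makes this batch random
    images = images.to(device)
    with torch.no_grad():  # inference only, no gradients needed
        reconstructions = decoder(encoder(images))
    return images.cpu(), reconstructions.cpu()


# toy stand-ins: the "encoder" flattens, the "decoder" restores the image shape
encoder = nn.Flatten()
decoder = nn.Unflatten(1, (1, 32, 32))
loader = DataLoader(
    TensorDataset(torch.rand(8, 1, 32, 32), torch.zeros(8)), batch_size=4, shuffle=True
)

real, recon = reconstruct_random_batch(loader, encoder, decoder)
print(real.shape, recon.shape)
```

From here, the two batches could be tiled into figures (for example with matplotlib) and saved under the configured file names.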

Figure 5 presents the trained autoencoder’s reconstructions of the test images shown in Figure 6. The reconstructions are well rendered, indicating that the autoencoder does a reasonably good job of reconstructing its inputs. However, on closer comparison of the two figures, it’s amusing that the model misinterpreted a trouser as a bag and a t-shirt 😀!

Figure 5: Reconstruction by the autoencoder on the test images after training (source: image by the author).
Figure 6: Real test images fed to the autoencoder after training (source: image by the author).

Visualize the Latent Space of the Trained Encoder

logging.info("Creating and Saving the Latent Space Plot of Trained Autoencoder")
# call the 'plot_latent_space' function from the 'utils' module to create a 2D
# scatter plot of the latent space representations of the test data
utils.plot_latent_space(test_loader, encoder, show=False)

Here, we create a 2D scatter plot of the latent space representations of the test data.

The plot_latent_space function from the utils module is called. This function takes the test data loader and the trained encoder as inputs. It then passes the test data through the encoder to generate the latent space representations, which are then plotted on a 2D scatter plot.

The show=False parameter means the plot is not displayed right after it is created; instead, it is saved to a file for later viewing. Note that the file path for the saved plot is specified inside the plot_latent_space function.
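The data-gathering half of such a function might look like the sketch below. `collect_embeddings` is a hypothetical helper (the tutorial’s own utilities are not shown in this section), and the linear "encoder" and random images are placeholders for the trained encoder and the test loader.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def collect_embeddings(loader, encoder, device="cpu"):
    """Hypothetical helper: encode every batch and keep the labels for coloring."""
    all_embeddings, all_labels = [], []
    encoder.eval()
    with torch.no_grad():
        for images, labels in loader:
            all_embeddings.append(encoder(images.to(device)).cpu())
            all_labels.append(labels)
    return torch.cat(all_embeddings), torch.cat(all_labels)


# toy 2-D "encoder" and a fake dataset of 10 images with labels 0..9
encoder = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 2))
dataset = TensorDataset(torch.rand(10, 1, 32, 32), torch.arange(10))
embeddings, labels = collect_embeddings(DataLoader(dataset, batch_size=4), encoder)
print(embeddings.shape, labels.shape)  # torch.Size([10, 2]) torch.Size([10])
```

With matplotlib, a call such as `plt.scatter(embeddings[:, 0].numpy(), embeddings[:, 1].numpy(), c=labels.numpy(), cmap="tab10")` followed by `plt.savefig(...)` would then produce a label-colored scatter plot.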

Figure 7 visualizes the latent space of the encoder trained on the Fashion-MNIST dataset. To produce it, we color each point by the label of the corresponding image. Now the structure becomes very clear!

Figure 7: Visualization of the Latent Space for the Test Dataset, Colored by Clothing Category (source: image by the author).

The beauty of the autoencoder is that even though the clothing labels were never shown to the model during training, it has naturally grouped items that look alike into the same part of the latent space. For example, the orange cloud of points in the top right corner of the latent space consists of trouser images, and the blue cloud toward the top center consists of images from the T-shirt/top category.


Visualization: Sample Uniformly from Latent Space

We can generate new images by picking random points within the latent space and using the decoder to map them back into pixel (image) space.
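One simple way to sample "uniformly" is to draw points from the bounding box of the existing embeddings and decode them. This is a sketch under that assumption: the tutorial does not state exactly how the points in Figure 8 were drawn, `sample_uniform_latent` is a hypothetical helper, and the random decoder stands in for the trained one.

```python
import torch
from torch import nn


def sample_uniform_latent(embeddings, num_samples):
    """Hypothetical helper: draw points uniformly from the embeddings' bounding box."""
    low = embeddings.min(dim=0).values
    high = embeddings.max(dim=0).values
    u = torch.rand(num_samples, embeddings.shape[1])
    return low + u * (high - low)  # rescale [0, 1) to [low, high) per dimension


torch.manual_seed(0)
embeddings = torch.randn(500, 2) * torch.tensor([3.0, 5.0])  # fake 2-D embeddings
samples = sample_uniform_latent(embeddings, num_samples=18)

# toy decoder mapping a 2-D point to a 32x32 image; a trained decoder goes here
decoder = nn.Sequential(nn.Linear(2, 32 * 32), nn.Sigmoid(), nn.Unflatten(1, (1, 32, 32)))
with torch.no_grad():
    generated = decoder(samples)
print(samples.shape, generated.shape)
```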

Figure 8 shows the uniformly sampled embeddings (in blue) in the latent space, with corresponding images generated by the decoder in Figure 9.

Figure 8: Visualization of the latent space, with uniformly sampled points shown in blue (source: image by the author).

Each blue dot (in Figure 8) corresponds to one of the images shown in Figure 9, with the embedding vector displayed beneath. Observe that some generated items appear more lifelike than others. What could be the reason for this?

Figure 9: Images generated from points sampled uniformly within the latent space, with each image’s latent vector shown below it (source: image by the author).

To answer this question, let’s first note some characteristics of the overall distribution of points in the latent space, referring to Figure 7:

  • Certain clothing items occupy a very small region, while others span a much larger area.
  • The distribution is neither symmetric around the point (0, 0) nor bounded. For instance, there are significantly more points with negative y-values than positive ones, and more points with positive x-values than negative ones. Some points even reach y-values below -11.
  • There are large gaps between the colored clusters where very few points lie.
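These properties can also be checked numerically rather than by eye. The sketch below summarizes a set of 2-D embeddings; `describe_latent_distribution` is a hypothetical helper, and the four points are made up purely to show the output format.

```python
import torch


def describe_latent_distribution(embeddings):
    """Summarize how centered and how spread out 2-D latent points are."""
    x, y = embeddings[:, 0], embeddings[:, 1]
    return {
        "mean": embeddings.mean(dim=0).tolist(),
        "min": embeddings.min(dim=0).values.tolist(),
        "max": embeddings.max(dim=0).values.tolist(),
        "frac_x_positive": (x > 0).float().mean().item(),
        "frac_y_negative": (y < 0).float().mean().item(),
    }


# made-up embeddings: three points with negative y, one with positive y
embeddings = torch.tensor([[1.0, -2.0], [2.0, -11.5], [3.0, -0.5], [-1.0, 4.0]])
stats = describe_latent_distribution(embeddings)
print(stats)
```

A mean far from (0, 0), lopsided sign fractions, and extreme minimum or maximum values are all signs of the asymmetric, unbounded distribution described above.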

These characteristics make sampling from the latent space difficult. Later, we will overlay the latent space with decoded images on a grid to better understand why this unbounded and asymmetric latent space poses challenges.

The outputs in Figure 9 are of subpar quality: they come across as blurry, pixelated, and poorly formed. For example, the images in the first row, second column, and in the third row, fourth column, are not only ill-formed but also make it hard to tell which Fashion-MNIST class they are meant to represent.

One might attribute this poor quality to the corresponding latent points lying near the boundary of the distribution. However, we couldn’t expect much better results even if these points were placed centrally within the latent space, because the autoencoder’s latent space inherently lacks continuity.


Visualize the Image Grid on Embeddings

logging.info(
    "Finally, Creating and Saving the Linearly Separated Image (Grid) on "
    "Embeddings of Trained Autoencoder"
)
# Call the 'plot_image_grid_on_embeddings' function from the 'utils' module
# to create a grid of images linearly interpolated
# between embedding pairs in the latent space
utils.plot_image_grid_on_embeddings(test_loader, encoder, decoder, show=False)

Finally, we create a grid of linearly interpolated images between embedding pairs in the latent space.

The plot_image_grid_on_embeddings function from the utils module is called.

  • This function takes the test data loader, the trained encoder, and the trained decoder as inputs.
  • It uses the encoder to generate latent space embeddings of the test data.
  • Then, it selects pairs of these embeddings and linearly interpolates between them to create new embeddings.
  • These interpolated embeddings are then passed through the decoder to generate new images.
  • These images are arranged on a grid, each row corresponding to one pair of embeddings.
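The interpolation step in these bullets can be sketched in a few lines. `interpolate_embeddings` is a hypothetical helper, not the tutorial’s utility; the two endpoints are example values, and the random decoder stands in for the trained one. Each interpolated path decodes into one row of images.

```python
import torch
from torch import nn


def interpolate_embeddings(z_start, z_end, steps):
    """Hypothetical helper: evenly spaced points on the segment from z_start to z_end."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # shape (steps, 1)
    return (1 - alphas) * z_start + alphas * z_end  # broadcasts to (steps, dim)


z_a = torch.tensor([0.0, -11.8])  # e.g., the embedding of one test image
z_b = torch.tensor([4.0, 2.0])    # e.g., the embedding of another
path = interpolate_embeddings(z_a, z_b, steps=5)

# toy decoder in place of the trained one; the path decodes into one row of images
decoder = nn.Sequential(nn.Linear(2, 32 * 32), nn.Sigmoid(), nn.Unflatten(1, (1, 32, 32)))
with torch.no_grad():
    row_of_images = decoder(path)
print(path)
print(row_of_images.shape)
```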

In Figure 10, we’ve superimposed the latent space with decoded images arranged on a grid, and it’s already apparent that many of the decoder’s outputs are poorly formed.

Figure 10: A grid of embeddings decoded by the trained decoder, superimposed on the embeddings of the original dataset images, with each class shown in a different color (source: image by the author).

Let’s further analyze the issues and limitations of the autoencoder, as demonstrated in Figure 10:

  • If we pick points evenly across a bounded region of the latent space, we are more likely to get something resembling a sandal (class id 5) or trousers (class id 1) than a bag (class id 8). This is because the region of the latent space occupied by sandals (brown, see Figure 7) is larger than the region occupied by bags (light green).
  • Additionally, it is unclear how we should choose a random point in the latent space, since the distribution of these points is not defined. Technically, any point on the 2D plane could be a valid choice! There’s no guarantee that points will be centered around (0, 0), which makes sampling from our latent space challenging.
  • Finally, we notice voids in the latent space where none of the original images are encoded. For instance, large white spaces are visible at the edges of the domain. The autoencoder has no incentive to ensure that points there decode into recognizable clothing items, since very few training images are encoded there.
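The first bullet’s argument (a larger region means more hits when sampling evenly) can be probed with a rough proxy: label each sampled point with the class of its nearest encoded image and tally the classes. This nearest-neighbor count is an illustrative choice made here, not something the tutorial does; `nearest_class_counts` is hypothetical, and the tiny arrays are made up.

```python
import torch


def nearest_class_counts(samples, embeddings, labels, num_classes):
    """Label each sample by its nearest embedding's class and count per class."""
    distances = torch.cdist(samples, embeddings)  # (num_samples, num_embeddings)
    nearest = distances.argmin(dim=1)             # index of the closest embedding
    return torch.bincount(labels[nearest], minlength=num_classes)


# made-up embeddings: class 5 (sandal) spread widely, class 8 (bag) in one spot
embeddings = torch.tensor([[0.0, 0.0], [4.0, 0.0], [8.0, 0.0], [10.0, 0.0]])
labels = torch.tensor([5, 5, 5, 8])
samples = torch.tensor([[1.0, 0.0], [3.0, 0.0], [6.5, 0.0], [9.5, 0.0]])
counts = nearest_class_counts(samples, embeddings, labels, num_classes=10)
print(counts[5].item(), counts[8].item())  # 3 1
```

Run on real test embeddings and evenly spaced samples, a tally like this would make the imbalance between large and small class regions explicit.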

Even central points may not decode into well-formed images (e.g., points where x ≥ 5.0 and y ≤ -5.0). This happens because the autoencoder isn’t compelled to make the latent space continuous. For instance, even though the point (0, -11.8) might decode into a satisfactory sandal image, there’s no mechanism to ensure that the nearby point (-1, -11.8) also yields a satisfactory sandal image.
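One rough way to probe this lack of continuity with a trained decoder is to decode a latent point and a slightly shifted neighbor and measure how much the image changes. `decoded_change` is a hypothetical diagnostic introduced here, and the randomly initialized decoder below only demonstrates the mechanics; meaningful numbers would require the trained decoder.

```python
import torch
from torch import nn


def decoded_change(decoder, z, delta):
    """Mean absolute pixel change when a latent point is nudged by delta."""
    with torch.no_grad():
        image_a = decoder(z.unsqueeze(0))
        image_b = decoder((z + delta).unsqueeze(0))
    return (image_a - image_b).abs().mean().item()


torch.manual_seed(0)
# toy decoder; with a trained decoder, compare nearby points such as the pair in the text
decoder = nn.Sequential(nn.Linear(2, 32 * 32), nn.Sigmoid())
z = torch.tensor([0.0, -11.8])
print(decoded_change(decoder, z, torch.tensor([-1.0, 0.0])))
print(decoded_change(decoder, z, torch.zeros(2)))  # 0.0: no nudge, no change
```

Nothing in a plain autoencoder’s reconstruction loss keeps this change small between latent points that no training image maps to, which is the gap the next installment addresses.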

This issue is subtle in two dimensions: the autoencoder has only a small number of dimensions to work with, so it naturally packs clothing groups close together, leaving relatively little space between them. However, the problem becomes more glaring as we use more latent dimensions to generate more complex images, such as faces. If we give the autoencoder free rein over how it uses the latent space, there will be massive gaps between groups of similar points, with no incentive for the intervening space to decode into well-formed images.

In the next installment of our autoencoder series, we will explore how variational autoencoders address the above-mentioned challenges.




Summary

This tutorial focused on the practical aspects of implementing a convolutional autoencoder, beginning with an overview of the dataset, including its class distribution, the data preprocessing steps, and the split between training and testing data.

Next, to set up the implementation environment, the post configures the prerequisites and defines essential utilities, including functions for extracting and displaying random images, validating results, obtaining test embeddings, and plotting the latent space.

With everything set up, the post walks through training the autoencoder, highlighting key considerations and potential challenges. Finally, it tests the trained autoencoder with several experiments that demonstrate both its effectiveness and its limitations in producing reconstructions, and shows readers how to interpret and use the results.


Citation Information

Sharma, A. “Implementing a Convolutional Autoencoder with PyTorch,” PyImageSearch, P. Chugh, A. R. Gosthipaty, S. Huot, K. Kidriavsteva, and R. Raha, eds., 2023, https://pyimg.co/t0noi

@incollection{Sharma_2023_Implementing,
  author = {Aditya Sharma},
  title = {Implementing a Convolutional Autoencoder with PyTorch},
  booktitle = {PyImageSearch},
  editor = {Puneet Chugh and Aritra Roy Gosthipaty and Susan Huot and Kseniia Kidriavsteva and Ritwik Raha},
  year = {2023},
  url = {https://pyimg.co/t0noi},
}


The post Implementing a Convolutional Autoencoder with PyTorch appeared first on PyImageSearch.


July 17, 2023 at 06:30PM

=============================
The original post is available in PyImageSearch by Aditya Sharma
this post has been published as it is through automation. Automation script brings all the top bloggers post under a single umbrella.
The purpose of this blog, Follow the top Salesforce bloggers and collect all blogs in a single place through automation.
============================
