Idea and motivations
In this blogpost I’m going to share my journey (still in progress) towards understanding how Stable Diffusion and the new wave of diffusion-based generative models, like DALL-E 2 and Imagen, work.
Everything started in the context of the fastai (im)practical deep learning for coders 2022 course, where we’re rebuilding Stable Diffusion from scratch. Stable Diffusion (SD from now on) is the popular new image-generation model, available for anyone to run on their own computer, which can produce amazing images like these from a simple text prompt such as “an astronaut riding a horse”:
Jeremy Howard gave a quick overview of how Stable Diffusion works, and I’ve always admired his ability to demystify and simplify complex concepts; as he put it: <<deep learning?! It’s just a bunch of matrix multiplications>>.
So in this post, I’d like to share a simple, intuitive metaphor for understanding one particular question about Stable Diffusion:
Given that the SD model is able to navigate the manifold of latents to produce a valid image, what does this manifold look like?
What is a “manifold of latents”? That cryptic mathematical phrase is the very essence of the problem. Roughly, a manifold is a mathematical abstraction representing a set of items (like images) connected to each other so that some are closer than others. A “latent” is a kind of short, meaningful summary of a thing. So the manifold of latents of valid images is a connected set of short numeric summaries of valid images. Yes, this is rather abstract!
That is exactly why a concrete metaphor, like the one I present here, is helpful. And it is worth understanding if you want to control image generation or interpolate between images: in other words, if you want to navigate the latent space.
I’ll show how I probed this space with different ideas and techniques: I hope this will help you better understand how the main components of SD work and how to hack them to implement your own ideas.
The key metaphor is the idea of an island of real images: this island is how I’ve imagined the manifold of valid images. If you think of an image as a point in this latent space (the point you get by encoding the image), the island represents the cluster of points corresponding to images you can recognize (noisy or not), the sea represents pure noise, the coast of the island is the blurry border between what you can recognize and what you can’t, and on the mountain peaks you’ll find sharp images.
As we know, SD (using a UNET) can navigate this space, gradually removing noise to go from the deep sea to the highest peaks, so you can think of it as a navigation system, a compass for voyaging in this multidimensional landscape! I’ve found that this perspective provides a helpful intuition that makes the system easier to reason about and makes it more fun to find cool new ways to navigate. For example, if we can show that the island exists:
- Adding noise to an image, as we do during training, is like going down the mountain, while inference is like climbing the particular peak you’ve specified in the prompt.
- Interpolating between real images is like walking along the great wall between different peaks.
This post is about my search for clues of the existence of this island and my attempt to figure out its shape: in other words, I wanted to "see" the island!
First, I will summarize some of the basic components of these systems. Then I’ll talk about a more intuitive way of thinking about the system and share a few experiments that flesh out the intuitive concepts.
The real images island
The idea of the Real Images Island came to me while I was trying to understand how these models work. I’ve called it an island because I think of it as a cluster of recognizable images surrounded by an increasing amount of noise. I can’t say whether there is a big archipelago of different classes of images or whether they are connected in a single multidimensional Pangaea (that would be an interesting direction for a future "expedition").
The peak of sharpness describes how sharp real images are. It also expresses how "unstable" the diffusion process is up there: as soon as we perturb the latent a bit, we quickly fall down the cliff. Finally, it recalls how hard it is to get from noise to a real image (as hard as climbing a steep mountain!).
The sea of noise communicates how big that domain is, how "indistinguishable" the drops of water are from each other, and how you lose signal/information the deeper you go into the abyss.
The island represents the middle ground: the border between images we can understand and the noise that grows bigger and bigger the deeper we dive into the "sea".
A journey from Idea → Theory → Practice
This first section (Idea and Motivations) presented the "idea": to verify that in the latent space (the numerical representation of images), real images are surrounded by an increasing amount of noise as we move away from these "peaks of sharpness".
The second section (Some Background) briefly introduces some of the most relevant theoretical concepts that we’ll use later.
The last section (Let’s find the island) dives into the code and my findings.
Some background
To get everyone on the same page, it’s useful to have a quick recap of the core ideas and concepts behind this and other image-generation models. This section briefly introduces some key ideas from the Latent Diffusion Models paper (LDM), the central paper behind Stable Diffusion.
Definitions:
- Latent: a numerical representation of a signal (e.g. an image); it’s usually a multidimensional vector where each component contributes to describing certain characteristics of the signal itself. Neural networks are a popular way to encode images and create latents. In general we have a good model if similar images lead to similar latents: for example FaceNet, a popular face recognition model, relies on the fact that latents created from pictures of my face are close to each other while being far from latents made from images of other people’s faces. "My understanding of the Manifold Hypothesis" is one of the very best visual explanations of this concept.
- Real / Valid image: in this context I consider an image valid if it is visually correct and free of noise, regardless of its contents (e.g. "a flying donkey", if sharp and well represented, is a perfectly valid image).
- Dimensionality reduction: probably oversimplifying, it’s the process of representing an N-dimensional signal in a 2D plot, where N can be a big number (e.g. in our case N will be 4*64*64 = 16384). There are both linear (PCA, ...) and non-linear (TSNE, UMAP, VAE, ...) ways to reduce dimensionality; in general, the bigger N is and the richer the information it contains, the better it is to use a non-linear approach. These techniques are very useful to let the data speak for themselves: if a pattern emerges when looking at these 2D pictures, it usually means we’re heading in the right direction.
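As a concrete illustration of the linear end of this spectrum, here is a minimal PCA sketch in NumPy that flattens a batch of latents to N = 4*64*64 = 16384 dimensions and projects them onto 2D. The random `fake_latents` are placeholders, not real VAE outputs, and `pca_2d` is a hypothetical helper; in the experiments below TSNE and UMAP play this role.

```python
import numpy as np

def pca_2d(latents: np.ndarray) -> np.ndarray:
    """Project a batch of latents (B, C, H, W) onto their first two principal components."""
    x = latents.reshape(len(latents), -1)        # flatten: (B, C*H*W), e.g. (B, 16384)
    x = x - x.mean(axis=0, keepdims=True)        # center every dimension
    # Rows of vt are the principal directions, sorted by decreasing explained variance
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    return x @ vt[:2].T                          # (B, 2) coordinates for a scatter plot

rng = np.random.default_rng(0)
fake_latents = rng.normal(size=(50, 4, 64, 64))  # stand-in for 50 encoded images
coords = pca_2d(fake_latents)
print(coords.shape)  # (50, 2)
```

Each row of `coords` is one image, ready to be scattered on a 2D chart and colored by category.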
Towards Stable Diffusion
Once I started looking under the hood of Stable Diffusion, I saw some components that, at first sight, didn’t look like classic implementations of a VAE and a U-NET (e.g. the VAE can use classic KL regularization or a VQ-GAN-inspired one, and the U-NET has attention and conditioning modules).
So, starting from the Stable Diffusion paper (LDM - Latent Diffusion Models), I read back through what I understood to be the core-idea papers and divided them into two main research pillars:
- diffusion models: this line of research mainly influenced the design of the U-NET.
- latents encoding: this instead shaped the VAE and enabled the whole pipeline to be faster and to work on consumer hardware.
The following picture tries to summarize the streams of papers that eventually led to SD. An interesting thing I discovered is that the LDM paper was written by CompVis, the same research group (with some of the same authors) that published VQ-GAN a couple of years earlier.
AutoEncoder quick overview
AE - Auto Encoders:
[Encoder→Latent→Decoder]
I’ve always been fascinated by Auto Encoders: they can "compress" the input signal into a latent representation, extracting patterns and structure in a totally unsupervised way. The problem is that they don’t always work, and even when they do, a vanilla auto-encoder tends to just memorize (overfit) the input data, creating a "polka-dot" latent space where good signal exists only around the encodings of valid samples and weird things happen when interpolating between different samples.
VAE - Variational Auto Encoders:
[Encoder→LatentDistribution→Sample→Decoder]
The Variational Auto Encoder (VAE) approach tries to create a smoother latent space, where interpolating between samples leads to more reasonable intermediate states. It does so thanks to KL (Kullback-Leibler) divergence regularization, and this matters here because the KL-regularized VAE is one of the encoders proposed by the LDM paper (KL-reg actually seems to be the most common one and is the one I’ll use in my experiments; the other one they suggest is a VQ-reg one we’ll see later). Finally, another key insight about VAEs is that the encoder predicts a per-pixel distribution of latents instead of a single latent value (we’ll see this later in the code).
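The KL regularization has a simple closed form when the encoder predicts a diagonal Gaussian (a mean and a log-variance per latent element), as is standard for VAEs. The following is a minimal NumPy sketch of that penalty against a standard normal prior; `kl_to_standard_normal` is a hypothetical helper, not code from the LDM repository, and it says nothing about how strongly LDM weights this term.

```python
import numpy as np

def kl_to_standard_normal(mean: np.ndarray, logvar: np.ndarray) -> float:
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over every latent element.

    Per element: 0.5 * (mean^2 + var - 1 - log(var)).
    """
    return float(0.5 * np.sum(mean**2 + np.exp(logvar) - 1.0 - logvar))

shape = (4, 64, 64)
print(kl_to_standard_normal(np.zeros(shape), np.zeros(shape)))  # 0.0: already a standard normal
print(kl_to_standard_normal(np.ones(shape), np.zeros(shape)))   # 8192.0: 0.5 per shifted element
```

The penalty grows as the predicted distributions drift away from N(0, 1), which is what keeps the latent space compact and smooth enough to interpolate in.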
VQ-VAE - Vector Quantized Variational Auto Encoders:
[Encoder→LatentDistribution→Quantization→Decoder]
As an alternative to the classic KL-reg VAE, the Latent Diffusion Models paper (LDM, the main paper behind Stable Diffusion) suggests a VQ-regularized VAE.
VQ-VAE stands for Vector Quantized VAE and is a very interesting technique, the same one used by the original DALL-E architecture and by the VQ-GAN paper (a 2021 paper by the CompVis group that shares some of the main authors with the LDM paper).
The core idea behind VQ-VAE is to take the continuous output of the encoder and quantize it by replacing each vector with the index of the closest embedding in a vocabulary of possible states (called the codebook); these embeddings are trained together with the encoder.
So quantization transforms a continuous signal into a flattened 1D sequence of discrete tokens, which is a perfect input for a transformer; the role of the transformer is to create "semantic awareness" in the sequence, which is useful, for example, to condition, change or restore missing parts of the image. In other words, we transform an image into a sentence with a fixed number of words; during training we simultaneously find these words through discretization and give them a meaning by training their embeddings.
It’s important to notice that the output of a VQ-VAE encoder is continuous (quantization happens afterwards), because the VQ-reg variant of LDM uses this kind of encoder.
If you want to dive deeper into this topic, I suggest this great article: https://ml.berkeley.edu/blog/posts/vq-vae.
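The quantization step can be sketched in a few lines of NumPy: each continuous encoder vector is replaced by the index of its nearest codebook entry. This assumes the encoder output has already been flattened into N vectors of dimension D and that the codebook was learned elsewhere; VQ-VAE training details (straight-through gradients, commitment loss) are deliberately left out, and the toy codebook is made up.

```python
import numpy as np

def quantize(z: np.ndarray, codebook: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Map each continuous vector in z (N, D) to its nearest codebook entry (K, D).

    Returns the token indices (N,) and the quantized vectors (N, D).
    """
    # Squared Euclidean distance between every encoder vector and every code: (N, K)
    dists = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    indices = dists.argmin(axis=1)       # the discrete "words" of the image
    return indices, codebook[indices]    # embeddings looked up for the decoder

codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])  # toy codebook, K=3, D=2
z = np.array([[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7]])        # toy continuous encoder output
idx, zq = quantize(z, codebook)
print(idx)  # [0 1 2]
```

The `indices` array is the "sentence" a transformer would read; `zq` is what the decoder would receive.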
Stable diffusion quick recap
Long story short, Stable Diffusion is based on the Latent Diffusion Models (LDM) paper. It has three main components:
- VAE TO SIMPLIFY: as we’ve seen in the previous section, this step is one of the main contributions of LDM. Reducing the input dimensionality leads to better performance and lower memory usage with a very small quality loss (we’ll see that the VAE can encode/decode an image with differences you can’t notice with the naked eye). Moreover, the choice to have latents in the form of pseudo images (64x64 with 4 channels; see the image of the encoded parrot in the previous section) enables the use of a convolutional backbone in the UNET and simplifies inspection and conditioning. It’s important to notice that the whole diffusion process happens in the latent space, so think of the latent space as the space where we move during our search.
- UNET TO CONDITION AND REMOVE NOISE: this is the step that does the "diffusion": LDM uses a convolutional UNET to predict the noise inside the latent. It can optionally also take a conditioning signal as input which, using a classifier-free guidance approach, pushes the diffusion process towards a given direction (you can control the strength of the conditioning signal with the guidance scale). This noise removal process is repeated multiple times until we obtain a sharp image. Think of it as the navigation system, the compass that takes you from where you are (a potential initial seed) to the destination you’ve specified (the conditioning signal).
- CLIP AS CONDITIONING: Stable Diffusion in particular uses a pre-trained CLIP text encoder to condition the UNET. Think of the prompt as the direction you want to go.
To summarize, Stable Diffusion takes as input a text prompt and optionally a seed/starting image (if there is none, it starts from pure noise in latent space); then it iteratively removes noise from the input latent, trying to move in the direction of the prompt you’ve specified. After a certain number of iterations, the resulting latent is decoded back by the VAE into an RGB image.
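The classifier-free guidance step from the UNET bullet boils down to a linear combination of two noise predictions: one made with an empty prompt and one made with the real prompt. A minimal sketch, assuming both predictions are arrays of the same shape; the random arrays stand in for real UNET outputs, and 7.5 is only a commonly used default, not a value taken from this post.

```python
import numpy as np

def classifier_free_guidance(noise_uncond: np.ndarray, noise_text: np.ndarray,
                             guidance_scale: float = 7.5) -> np.ndarray:
    """Combine the UNET noise predictions for the empty prompt and the real prompt."""
    # Start from the unconditioned prediction and move along the direction the
    # prompt adds; guidance_scale > 1 exaggerates that direction.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

rng = np.random.default_rng(0)
noise_uncond = rng.normal(size=(4, 64, 64))   # placeholder: UNET output for the empty prompt
noise_text = rng.normal(size=(4, 64, 64))     # placeholder: UNET output for the real prompt
guided = classifier_free_guidance(noise_uncond, noise_text)
```

With `guidance_scale=1.0` the result is just the text-conditioned prediction; larger values steer the compass more decisively towards the prompt.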
Let’s find the island
Starting from here, we’ll follow the flow of the notebook end-to-end. Along the way I discovered a lot of interesting facts about how Stable Diffusion works: hopefully they will help you too.
Do you want to search the island by yourself?
You can find the code for all this section here: searching_the_real_images_island.ipynb
The code is anything but optimized (or bug-free). I’ve successfully tested it with an Nvidia RTX 3090; running it end-to-end takes about 30 minutes, mainly because of TSNE/UMAP.
NOTE: the notebook has been heavily inspired by Jonathan Whitaker’s Stable Diffusion Deep Dive notebook from the fastai course (linked at the end of this post).
Exploring “VAE” space
At first I didn’t trust the VAE, so to prove that it works I encoded and decoded an image and took the difference between the two (aka the reconstruction error): it turned out that this difference is so small that it can be ignored. Obviously we expect this from a well-trained model, but looking at the numbers is pretty impressive.
Thinking about it, an interesting research direction is to assess VAE performance on images that are probably outside the domain it was trained on, like pixel-perfect 2D geometric patterns or charts. I leave this to the reader or to a future "expedition".
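Encoding and decoding require the actual VAE (the notebook uses diffusers for that), so this sketch only shows one way to summarize the reconstruction error once you have both images. It assumes float arrays scaled to [0, 1]; the helper name and the PSNR metric are my additions, not necessarily what the notebook reports.

```python
import numpy as np

def reconstruction_error(original: np.ndarray, reconstructed: np.ndarray) -> dict:
    """Compare an image with its VAE round-trip; both are float arrays in [0, 1]."""
    diff = original.astype(np.float64) - reconstructed.astype(np.float64)
    mse = float(np.mean(diff ** 2))
    return {
        "mae": float(np.mean(np.abs(diff))),      # average per-pixel absolute error
        "max_abs": float(np.max(np.abs(diff))),   # worst single pixel
        # PSNR in dB for a [0, 1] signal; infinite for a perfect reconstruction
        "psnr": float("inf") if mse == 0 else 10 * float(np.log10(1.0 / mse)),
    }

original = np.zeros((3, 512, 512))              # placeholder image
reconstructed = np.full((3, 512, 512), 0.01)    # placeholder "decoded" image
print(reconstruction_error(original, reconstructed))
```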
While running this first experiment, I realized that, as we saw in the theory, the VAE encoder predicts a per-pixel distribution that we need to sample from in order to get a valid "latent" (4x64x64); so if the input is a 512x512 RGB image, the output of the encoder is made of two tensors: the mean (4x64x64) and the standard deviation (4x64x64).
The decoding step instead transforms a latent back into a regular 512x512 RGB image (a 3x512x512 tensor).
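To make these shapes concrete, here is a NumPy sketch with a made-up mean and std standing in for the encoder output. Sampling is `mean + std * eps`, which, as far as I know, is what `vae.encode(...).latent_dist.sample()` does in diffusers' `AutoencoderKL`; the last line also shows the 48x reduction from pixel space to latent space.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical encoder output for one 512x512 RGB image: mean and std share the latent shape
mean = rng.normal(size=(4, 64, 64))
std = np.full((4, 64, 64), 0.1)

# Draw one latent from the per-pixel Gaussian (the reparameterization trick)
latent = mean + std * rng.standard_normal(mean.shape)

pixels = 3 * 512 * 512        # size of the decoder output: 3x512x512
latent_size = latent.size     # size of the latent: 4x64x64
print(latent.shape, latent_size, pixels // latent_size)  # (4, 64, 64) 16384 48
```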
Can we distinguish noise from real images looking at VAE latents?
Originally this was the first question I tried to answer, and the steps I followed were:
- create a synthetic latents dataset with three categories: real photos, pure noise and a mix of the two.
- use dimensionality reduction techniques to see if a pattern emerges.
The pure noise, on the other hand, required a bit of care: we want our noise to look as if it came from one of the photos, to make the two harder to distinguish; moreover, what we really want is noise in latent space, not in RGB space. The picture below shows the difference between an image made of noise and an image produced by decoding a latent made of noise: as you can see, they are very different. So I computed the mean and std from a batch of real photo latents and used them to sample the noise, producing "noisy latents" with the same statistics as those produced by real images.
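A minimal sketch of this statistics-matched noise, assuming the stats are computed per channel over the batch and spatial dimensions (the post only says mean and std were computed from a batch of real latents, so a single global mean/std is an equally valid reading). `matched_noise` is a hypothetical helper and the "real" latents are synthetic.

```python
import numpy as np

def matched_noise(real_latents: np.ndarray, n: int, rng: np.random.Generator) -> np.ndarray:
    """Sample n noise latents whose per-channel mean/std match a batch of real latents.

    real_latents: (B, 4, 64, 64) latents of real photos.
    """
    # Per-channel statistics over batch and spatial dims -> shape (1, 4, 1, 1)
    mu = real_latents.mean(axis=(0, 2, 3), keepdims=True)
    sigma = real_latents.std(axis=(0, 2, 3), keepdims=True)
    # Broadcast the stats over a fresh batch of standard normal samples
    return mu + sigma * rng.standard_normal((n, *real_latents.shape[1:]))

rng = np.random.default_rng(0)
real = 3.0 + 2.0 * rng.standard_normal((16, 4, 64, 64))   # fake "real" latents
noise = matched_noise(real, n=8, rng=rng)
```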
Mixing noise and real photo latents was another easy part: I just linearly interpolated them.
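The linear interpolation used for the mixed category is a one-liner; `mix`, the placeholder latents and the `alpha` values below are illustrative, not the ones used in the notebook.

```python
import numpy as np

def mix(real: np.ndarray, noise: np.ndarray, alpha: float) -> np.ndarray:
    """Linear interpolation between latents: alpha=0 is the real latent, alpha=1 is pure noise."""
    return (1.0 - alpha) * real + alpha * noise

rng = np.random.default_rng(0)
real = rng.normal(size=(4, 64, 64))     # placeholder for a real photo latent
noise = rng.normal(size=(4, 64, 64))    # placeholder for a statistics-matched noise latent
mixes = [mix(real, noise, a) for a in np.linspace(0.2, 0.8, 4)]  # hypothetical mixing levels
```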
Another very important detail is that each category was built from different images: the photos used for the noisy (mixed) samples are different from the real ones.
Once the latents dataset was created, the next step was to analyze it with two popular dimensionality reduction techniques: TSNE and UMAP. Surprisingly, both of them lead to a good separation between pure noise and the rest (at least for the vast majority of samples), but in both cases, looking at the charts, it’s hard to tell real images apart from noisy images.
The second experiment focused on the noise level. This is a key concept in the Denoising Diffusion Probabilistic Models paper (DDPM), which calls it "t" (time step) and uses it both during training and inference.
To do so I used the same approach as the previous experiment, but with a different latents dataset and a slightly different way to create noisy samples. The new dataset is designed to highlight what happens to an image when we add noise, and it takes into account that there are many different ways (noise paths) to add noise to the same image. The dataset has been created using:
- 80 different real images (seeds)
- 5 noise steps ranging from 0 to 30 for each image
- 2 variants for each noise step
NOTE: noise samples are now created slightly differently: previously they were created by mimicking the stats (mean, std) of latents coming from real images; now the noise component has zero mean and its std controls the amount of noise added. I did this because noisy samples are computed by adding this noise component to a real image latent in the latent domain, so zero-mean noise should avoid biasing the result (the visual results obtained by decoding these latents seem to confirm it’s the right direction).
For each real image we have 31 different latent versions (one real and 30 noisy ones at different levels) that can be grouped into a total of 16 possible noise paths.
So in total we have 16*80 = 1280 possible paths and 31*80 = 2480 different latents.
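The counting above can be checked with a small sketch that builds the noise tree for one seed. It assumes each child latent adds fresh zero-mean Gaussian noise to its parent (my reading of "noise paths"; the notebook might instead add noise directly to the real latent), and the per-level `stds` are hypothetical placeholders, since the post only says the steps range from 0 to 30.

```python
import numpy as np

def noise_tree(seed_latent, stds, branches=2, rng=None):
    """Hierarchical noise dataset for one seed latent.

    Level 0 is the real latent; every latent at level k spawns `branches`
    children at level k + 1 by adding zero-mean Gaussian noise with std stds[k].
    Returns a list of (level, latent) pairs.
    """
    if rng is None:
        rng = np.random.default_rng()
    nodes = [(0, seed_latent)]
    frontier = [seed_latent]
    for level, std in enumerate(stds, start=1):
        # Each parent gets `branches` independent noisy children: the noise paths fork here
        frontier = [parent + rng.normal(0.0, std, size=parent.shape)
                    for parent in frontier for _ in range(branches)]
        nodes += [(level, z) for z in frontier]
    return nodes

rng = np.random.default_rng(0)
seed = rng.normal(size=(4, 64, 64))   # stand-in for one real-image latent
stds = [5.0, 5.0, 5.0, 5.0]           # hypothetical noise std per level
tree = noise_tree(seed, stds, rng=rng)
n_paths = 2 ** len(stds)              # one root-to-leaf path per leaf
print(len(tree), n_paths, 80 * len(tree), 80 * n_paths)  # 31 16 2480 1280
```

With 4 noise levels below the real latent and 2 children per node, the tree holds 1 + 2 + 4 + 8 + 16 = 31 latents and 16 leaves, matching the 31 latents and 16 noise paths per seed described above.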
The reason behind this hierarchical noise schema is to deliberately imbalance the dataset, reflecting the fact that our valid images manifold is a "small" subset of the possible latent space (this intuition comes from the manifold hypothesis and from the fact that for each valid image there are many different ways to obtain a noisy one).
This time too, after dimensionality reduction we didn’t see any ordering by noise level, but interestingly both plots show that latents derived from the same "real image" (seed) are close to each other; this is even more evident for TSNE, where, surprisingly, the seed (in black) is always at the center of the micro-cluster containing all 31 variants.
I’m pretty sure this is related to the fact that we’re constructing noisy samples by adding zero-mean noise with increasing std to the same image, but this is similar to how the diffusion process works. After a certain amount of noise you can’t visually tell two samples coming from different seed images apart, but we could probably separate them by looking at their mean, since the zero-mean noise we add only slightly affects the final image mean (I suspect this is what drives the emergent grouping pattern).
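This intuition about the mean can be checked numerically: adding zero-mean noise with standard deviation `std` to a 16384-element latent changes individual values by about `std`, but shifts the overall mean only by something on the order of `std / sqrt(16384) = std / 128`. A small sketch on a random stand-in latent; `mean_drift` is a hypothetical helper.

```python
import numpy as np

def mean_drift(seed_latent: np.ndarray, std: float, rng: np.random.Generator):
    """Add zero-mean noise and report (per-element change, shift of the overall mean)."""
    noisy = seed_latent + rng.normal(0.0, std, size=seed_latent.shape)
    per_element = float(np.std(noisy - seed_latent))              # close to std
    drift = abs(float(noisy.mean()) - float(seed_latent.mean()))  # on the order of std / 128
    return per_element, drift

rng = np.random.default_rng(0)
seed_latent = rng.normal(loc=0.5, size=(4, 64, 64))   # stand-in for a real-image latent
for std in [1.0, 10.0, 30.0]:
    per_element, drift = mean_drift(seed_latent, std, rng)
    print(f"std={std:>4}: per-element change={per_element:.2f}, mean drift={drift:.4f}")
```

Even when the noise completely swamps each individual value, the mean stays close to the seed's, which is consistent with the grouping by seed seen in the TSNE plot.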
In general, if TSNE can figure out that all of those images are noisy versions of a given seed, this is a good sign: there should probably be a way to remove that noise.
As we saw in the previous section on how Stable Diffusion works, the role of the UNET is to predict the noise in the input latent, so somewhere inside the UNET there should be the ability to assess how much noise the input latent contains. To check this, we’ll use dimensionality reduction to analyze the latents in the UNET bottleneck (aka the UNET latents).
To recap:
- We’ll take the latents from the previous dataset (the one containing noise paths) and encode them with the UNET encoder to transform them into UNET latents.
- We’ll analyze them with TSNE and UMAP to see if any pattern arises.
I had a hard time extracting the UNET latents, but I found a way to monkey patch the UNET encoder (UNet2DConditionModel) and make it save the output I wanted. I should absolutely thank Pedro again for pointing me to the right part of the diffusers code: that saved me a lot of reverse engineering effort.
Another important thing to consider is that we don’t want our UNET latents to be affected by any prompt, so I chose to pass an "unconditioned prompt" (aka an empty string, the same one we usually pass to leverage classifier-free guidance). Moreover, all the latents were encoded with the timestep set to zero: the purpose is to avoid giving the encoder any clue about the noise level, although we’re probably biasing its output.
A fun fact I discovered is that the dimensionality of the UNET-LATENTS is bigger than the dimensionality of the input VAE-LATENTS plus that of the conditioning signal:
- dimensionality of VAE latent 16384
- dimensionality of text embedding latent 59136
- dimensionality of UNET input 75520
- dimensionality of UNET sample encoded 81920
NOTE: the conditioning signal (aka the CLIP-encoded prompt) has a higher dimensionality than the VAE-LATENT.
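The arithmetic behind these numbers can be reproduced with the shapes I believe Stable Diffusion v1 uses: 77 CLIP tokens with 768 features each for the prompt, and a bottleneck of 1280 channels at 8x8 for a 64x64 latent. Treat these factorizations as my assumption; only the totals come from the list above.

```python
vae_latent = 4 * 64 * 64                   # 16384: 4 channels at 64x64
text_embedding = 77 * 768                  # 59136: assumed 77 CLIP tokens x 768 features
unet_input = vae_latent + text_embedding   # 75520: everything the UNET receives
unet_bottleneck = 1280 * 8 * 8             # 81920: assumed 1280 channels at 8x8

print(vae_latent, text_embedding, unet_input, unet_bottleneck)  # 16384 59136 75520 81920
print(unet_bottleneck > unet_input, text_embedding > vae_latent)  # True True
```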
For dimensionality reduction, I chose to focus on UMAP because its n_neighbors (number of neighbors) parameter makes it possible to control how much global structure (relations with respect to all points) is preserved, while TSNE tends to favor local structure (relations with close neighboring points). If you want to dive deeper into the TSNE/UMAP differences regarding local/global structure, I suggest taking a look at this: https://towardsdatascience.com/tsne-vs-umap-global-structure-4d8045acba17
The first plot was really promising: all points with the same color come from the same seed image (unfortunately this isn’t completely true, because we have 80 seeds and only 20 colors in this colormap, so each color is shared by 4 different seeds, but for this initial test it’s ok). The big circle represents the real image (peak of sharpness), and for some of them I’ve plotted all the noise paths.
Here you can see two versions of the same plot: the positions are the same in both, computed by reducing the dimensionality to 2D with UMAP. One colors the dots with the modified terrain colormap and the other shows the image behind each latent (aka the image created by decoding each original VAE-LATENT back to RGB).
As you can see, this time the concept of noise level is clearly visible: points with similar levels of noise are close to each other.
Conclusion
The plot depends on carefully chosen UMAP parameters, its shape is not circular, and we have sparse "land" and super-high-density deep "water", but we’ve found the island.
Why is it important to find the island? Now that we’ve seen that it exists, the island metaphor can be used to describe and build intuitions about the different operations we can perform, like:
- TRAINING UNET: we start from valid points (peaks) and add noise, moving away from the island into the sea.
- INFERENCE GENERATING NEW IMAGES: we start from a point in the sea of noise and search for the highest peak of sharpness in the direction of the prompt.
- INTERPOLATE: we have a nice interpolation between two latents if we walk from peak to peak between the two points.
LINKS:
- Real images island notebook: https://github.com/artste/fastaisf/blob/master/diffusion-nbs/searching_the_real_images_island.ipynb
- VQ-VAE: https://ml.berkeley.edu/blog/posts/vq-vae
- Latents and the manifold Hypothesis: https://www.youtube.com/watch?v=BePQBWPnYuE
- Jonathan Whitaker’s Stable Diffusion Deep Dive: https://github.com/fastai/diffusion-nbs/blob/master/Stable%20Diffusion%20Deep%20Dive.ipynb
- Pedro Cuenca’s notebook: https://github.com/fastai/diffusion-nbs/blob/master/stable_diffusion.ipynb
Credits:
- Zach and Pedro for pushing me into Stable Diffusion from the very beginning.
- Delft study group and especially Alex and Suvash for the support and help in making this possible.
- Jeremy, Tanishq and the fastai forum: a continuous source for learning and inspiration.
- Finally, Alexis Gallagher for all the insights and for pushing (actually kicking) me so hard to make this blog post happen.
Future development:
This post has already grown too much and I need to wrap it up, but in the (HARD!) process of making it I came across some interesting ideas I want to pursue in the future:
- Expand the dataset with non-images: see what happens when adding pixel-perfect chart plots or solid colors (e.g. encode a 512x512 square image filled with the same red pixel).
- Reduce dimensionality with TriMap and PaCMAP: these are more modern techniques (especially PaCMAP) that seem to overcome the local vs global structure tradeoff we’ve seen with TSNE and UMAP.
Behind the scenes - aka the lost section:
The audience for this section is clearly friends and the fastai community in general, but I want to post it anyway because this is the real story behind this post; I hope it will encourage others to follow the community, join study groups and share their experiences.
Since the previous “practical” edition of the 2022 course, I’ve been attending the great Delft study group (thanks again to Alex for setting it up and to Pedro for pulling me in). The study group meets on Sunday mornings to review the topics from the weekly lesson, and everyone can ask questions and share their opinion/understanding.
Recently we were discussing the latent space of images, and I was trying to find a metaphor to help me explain the core idea behind diffusion models:
- TRAINING: you train them by adding noise to real images.
- INFERENCE: you ask the model to remove noise, moving towards a given “direction” (e.g. the conditioning prompt).
I explained it by saying that the model (U-NET) acts as a navigation system that drives you from pure noise to a sharp image according to the direction you’ve specified in the prompt (e.g. an astronaut riding a horse).
I should thank Suvash, who said: <<ok Stefano, you can try to present this idea next week>>; without the pressure of “I must deliver something by next Sunday”, I would probably have procrastinated forever and never pushed myself so far into this domain.
At the end of the day, this blog post summarizes what I’ve learned trying to deliver something by next Sunday ;-)
Long story short, this didn’t take me a single weekend but “a couple” of them, during which I incrementally presented my findings to the team and got invaluable feedback and suggestions on how to proceed. Learning by presenting is a powerful tool that forces you to review your understanding and really dive into the concepts.