Linear Diffusion: Building a Diffusion Model from Linear Components
It seems like everyone is releasing some cool model right now, so why not join the fun? Here at Count Bayesie we are introducing a new model unlike anything you’ve seen before: Linear Diffusion! A diffusion model that uses only linear models as its components (and currently works best with a very simple “language” and MNIST digits as images).
You can get the model and code on GitHub right now. If you’d like to learn more about Linear Diffusion, read on!
Here is a basic overview of Linear Diffusion’s architecture:
Why Linear Diffusion you might ask? As you may know, I’m a big believer in the power of linear models, and have many times argued one should never start any ambitious project without first building a linear model to see if you can get some decent performance from that. Only after a simple model shows promise should you add complexity. Surely this wisdom must fail when talking about state of the art models such as Stable Diffusion and Dall-E 2 right!? Is it even possible to have a linear diffusion model?
That’s what I wanted to find out! To make this problem a bit easier for a linear model to figure out I’ve constrained the problem space a bit including:
Statements in the language are limited to just the string representations of the digits 0-9
The images are the MNIST digits
Given these constraints, it turns out we can build such a model and can get some surprisingly good outputs from it! Here’s an example of Linear Diffusion’s responses to various prompts:
While certainly not an “astronaut riding a horse”, I personally think these are pretty impressive results for what is ultimately a linear model. As an added bonus, working through building Linear Diffusion provides a pretty good, high level overview of what real diffusion models are doing!
A Quick Tour Through Diffusion Models
Before we can build our Linear Diffusion model, we first need to understand what diffusion models are doing under the hood. It’s difficult to overstate the complexity of these models in practice, so be advised: we are looking at this model from an extremely high level of abstraction. However, even this high level of abstraction will allow us to map state of the art diffusion architectures to an implementation using only linear models.
Goal: Generating Images from a Text Description
Let’s be clear about what our final goal is: We want to provide our model with a text description of what we want to see, and have the model return a generated image matching that description. For example we might enter “A salami sandwich” and hope to get an image like this out of Stable Diffusion:
Most diffusion models we’re familiar with allow you to enter arbitrary text; however, for our simple model we’ll limit our language to statements representing digits 0-9. So if we enter “5” we might get an image like this:
Here’s a very quick summary of how these models achieve this goal: diffusion models learn to de-noise a noisy image in a way that agrees with the text describing the image. Then we ask the trained model to de-noise pure noise in a way that agrees with the text description and poof! we have a novel image generated by the model!
Now let’s step through this process a bit more slowly.
A Map for Training and Prediction
It can be useful to visualize our journey, and step through this process using a visualization as our map. Even at a high level of abstraction, diffusion models have a lot of moving parts! Here is a visualization of how training and prediction works for a diffusion model:
With this guide handy let’s step through the process!
We start with training data that includes images and descriptions of those images. The first two parts of this model happen essentially in parallel: text embedding and image encoding.
1a. Text Embedding
We see in part 1a that we take our original text description and then create a vector embedding from this. Typically in diffusion models this will involve not only very sophisticated language models such as an LSTM or Transformer, but also an additional step of aligning the text embedding with the image encoding/embedding. This way, in the latent text space we’re defining, text that describes visually similar events will be near each other.
In our Linear Diffusion we will simply one-hot encode our digit strings, and we won’t worry about making sure these text embeddings are close to our image embeddings.
1b. Image Encoding
At the same time as we do our text embedding, in step 1b we also encode our image. This serves two purposes. First it compresses our image tremendously, making all the models we’ll need later on much smaller (which is good because in practice diffusion models are gigantic as is!) Additionally, as we’ll see when we step through the implementation of linear diffusion, this embedding does help us to capture some information about the space of images in general, driving even random noise in the encoding space to look like real images when decoded. The images projected onto this embedding/encoding space are sometimes referred to as “latents”, since they are a latent representation of the image.
Real diffusion models typically use a Variational Autoencoder (VAE) for this step. Linear Diffusion will use Principal Component Analysis (PCA), a linear projection of our original data onto orthogonal vectors.
1b (continued). Adding noise to our image encoding
This is one of the most conceptually important steps of the process. We’re going to add noise to our image encoding and that same noise will be what we are trying to predict at a later step. The basic idea is: if we can predict the noise, then we can remove it.
For most Diffusion Models this process is fairly complex, involving adding progressively more or less noise to an image according to a noise schedule. Oftentimes information about how much noise is added to the embedding is included in the vector fed to the denoising model.
In Linear Diffusion we’ll simply be adding noise by sampling from a standard Normal distribution.
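To make the contrast with scheduled noise a bit more concrete, here’s a rough sketch of one common way full diffusion models do it (a DDPM-style schedule). This is not part of Linear Diffusion: the number of steps, the `betas` range and the helper name `add_scheduled_noise` are all assumptions chosen for illustration, and the 144-value vector is just a stand-in for an encoded image.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical schedule: 1000 steps with linearly increasing noise variance.
num_steps = 1000
betas = np.linspace(1e-4, 0.02, num_steps)
alpha_bars = np.cumprod(1.0 - betas)  # share of the signal variance left at step t

def add_scheduled_noise(latent, t):
    """Noise a clean latent up to step t; return the noisy latent and the noise."""
    noise = rng.normal(size=latent.shape)
    noisy = np.sqrt(alpha_bars[t]) * latent + np.sqrt(1.0 - alpha_bars[t]) * noise
    return noisy, noise

latent = rng.normal(size=144)  # stand-in for one encoded image
slightly_noisy, _ = add_scheduled_noise(latent, t=10)   # still almost the original
mostly_noise, _ = add_scheduled_noise(latent, t=999)    # almost nothing of it left
```

Early steps barely change the latent while late steps are nearly pure noise, which is why these models usually tell the denoiser which step it is looking at.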
2. Concatenating the embeddings
This step is fairly straightforward: the input to our denoising model is going to be the combination of both the text embedding and the noisy image embedding. The idea is that the denoising model will use information from the text embeddings to help it figure out where the noise might be in the image embedding.
3. The Denoising Model
This is the part of the diffusion model that really makes a diffusion model. In this stage we’re going to use all the information we have so far to attempt to predict the noise we’ve added to the image encoding. Our target at this stage is the very noise vector we used to add noise to our image encoding.
The denoising model for most Diffusion Models is a Denoising UNET which is a very sophisticated neural network for processing image data. Linear Diffusion will, of course, be using Linear Regression for this stage. However we will be adding interaction terms to give our model a bit of help in this ambitious task.
At this stage we are done with the training of the model. It’s helpful to also understand how we can go from here to predicting the original image. Understanding this will aid in understanding how this model can generate images from nothing.
4. Subtracting the noise
If we want to reconstruct the original image, the first step in the process is to subtract the noise from the noisy image embedding, leaving us with, if we estimated the noise correctly, the original embedding.
Linear Diffusion will just literally subtract the estimated noise from the embedding, but for real diffusion models the process can very closely mirror the training process of a neural net itself. Diffusion models will typically only subtract out a fraction of the estimated noise, run the new slightly denoised embedding back through the denoiser, and then apply that estimated noise. This process will be repeated multiple times until the final image is produced.
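Here’s a toy sketch of what that repeated, partial denoising loop looks like. Everything in it is made up for illustration: `toy_predict_noise` is a hypothetical stand-in that cheats by knowing the clean latent, and the fixed `step_size` is an assumption. Real samplers use update rules tied to their noise schedule, so treat this only as the shape of the loop.

```python
import numpy as np

rng = np.random.default_rng(0)
clean = rng.normal(size=144)  # stand-in for a clean latent

def toy_predict_noise(noisy):
    """Hypothetical denoiser: an oracle that knows the clean latent."""
    return noisy - clean

x = clean + rng.normal(size=144)  # start from a noisy latent
step_size = 0.2                   # remove only 20% of the estimated noise per step
errors = []
for _ in range(20):
    est_noise = toy_predict_noise(x)
    x = x - step_size * est_noise  # partially denoise, then go around again
    errors.append(np.linalg.norm(x - clean))
```

With this oracle the remaining error shrinks by the same factor on every pass. A learned denoiser is far less tidy, which is part of why taking many small steps helps.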
5. Finally, decoding the latent, denoised image embedding
The last step is essentially the same for Linear Diffusion and typical diffusion models: we just take our encoder and run it backwards (or we could learn the decoder separately) and project our latent image back into a full sized image!
A fascinating model if you ask me but: how do we generate images from this?
A Map for Image Generation
The visual guide for generating images given a text prompt requires just a few minor modifications to our original model:
No need to retread many of the steps here which are the same as before. The key difference between training/prediction and generation is: we no longer use image embeddings but random noise in the embedding space.
I find this to be a pretty clever solution to the problem of creating images from nothing. To generate a novel image we start with just noise and the text embedding of what we want to see. The denoiser will then try to find what we have described in the noise we fed it, and predict what noise it believes is there.
After this we just repeat the process of reconstructing the original image using our denoised noise!
What’s cool about Linear Diffusion is that we aren’t taking any shortcuts with this part of the process. We are just giving the model the string “5” and it is taking a random vector in the latent space and using the process outlined above to “denoise the noise”. So even though many of the samples from Linear Diffusion might look just “okay”, they are, in my opinion, very remarkable given how complex the process is that generates them!
Building Linear Diffusion Step by Step
Now let’s walk through building Linear Diffusion piece by piece. Not only will this help us understand the general architecture of diffusion models even more, but it provides some useful insight into why they behave the way they do and potentially improves our understanding of the more sophisticated cases… plus it is cool to see that even linear models can produce something in this space!
Text “embedding” - One Hot Encoding
This is the most straightforward, and frankly boring, part of our linear model. Our “language” is going to be constrained to just string digits 0-9. We’re going to use a one-hot-encoder to create vector embeddings. Here’s the code:
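import numpy as np
from sklearn.preprocessing import OneHotEncoder

# On scikit-learn older than 1.2, pass sparse=False instead of sparse_output=False.
text_encoder = OneHotEncoder(categories='auto',
                             drop='first',
                             sparse_output=False,
                             handle_unknown='error')
# all_labels is the list of digit strings, one per training image
labels = np.array(all_labels).reshape(-1, 1)
label_embeddings = text_encoder.fit_transform(labels)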
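To see what this embedding looks like, here’s a NumPy-only sketch that mimics a one-hot encoding with the first category dropped. The helper `one_hot_drop_first` is a made-up name for this illustration; it is not part of scikit-learn or of the Linear Diffusion code.

```python
import numpy as np

def one_hot_drop_first(labels, categories):
    """One-hot encode labels, then drop the column for the first category."""
    index = {cat: i for i, cat in enumerate(categories)}
    full = np.zeros((len(labels), len(categories)))
    for row, label in enumerate(labels):
        full[row, index[label]] = 1.0
    return full[:, 1:]

digits = [str(d) for d in range(10)]
embeddings = one_hot_drop_first(["0", "5", "9"], digits)
print(embeddings.shape)  # (3, 9)
print(embeddings[0])     # the prompt "0" is the all-zeros vector
```

So every prompt becomes a vector of 9 values, with “0” represented by all zeros and each other digit by a single 1.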
There we go, the simplest possible text encoder for a very simple “language”. Now let’s look at some really neat stuff regarding how we work with image data!
Image encoding/decoding - PCA
We’re going to be using the MNIST digits as our training data since they are both very simple and still contain a fair amount of obvious non-linearity in their nature. We’ll start by taking a look at a sample of these as they appear in the dataset:
Encoding our MNIST digits
We’ll start by focusing on how we can encode our image data into a lower dimensional space. Currently each image is 28x28 pixels or a vector of 784 floats. We’re going to project this down to 12x12 images, or vectors of 144 floats. As mentioned before we’ll be using PCA to create this embedding. It is worth noting that an Autoencoder is a generalization of PCA to allow for non-linear projection. So we’re not really too far off the VAE that a real diffusion model would use, at least philosophically.
We’re also going to be using SKLearn’s StandardScaler to make sure that all of the values in our image vectors have means of 0 and variance of 1 before running them through PCA (without centering, the first principal component would mostly just capture the mean of the features; scikit-learn’s PCA centers the data on its own, so the scaler’s main contribution here is putting every pixel on the same scale). Here’s the very simple code to build our image encoder:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

LATENT_SIZE = 12

# Standardize each pixel, then keep the top 144 principal components.
image_encoder = Pipeline([('scaler', StandardScaler()),
                          ('PCA', PCA(n_components=LATENT_SIZE*LATENT_SIZE))])
# all_data is assumed to hold the images flattened to one row of 784 values each
X_encode = image_encoder.fit_transform(all_data)
We can now see how those very same images look in their latent space:
Hard to recognize our digits anymore! However these “latents” contain quite a bit of information. We can see exactly how much information these encodings contain by decoding them back into digits!
Decoding our latents back to digits
It is fairly straightforward for us to reverse our encoding and project the latents back into the original images. Of course we will lose a bit of information, but it is surprisingly little. Here is the code to transform these latents back (SKLearn makes this very easy, but the math is pretty straightforward as well):
X_decode = image_encoder.inverse_transform(X_encode)
recon_imgs = X_decode.reshape(X_decode.shape[0],
                              all_imgs.shape[1],
                              all_imgs.shape[2])
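For the curious, here’s a sketch of that math in plain NumPy: standardize, project onto the top principal directions, then undo both steps in reverse order. It runs on small random toy data (16 features, 4 latent values) rather than MNIST, and the variable names are my own, so take it as an illustration of the idea rather than what scikit-learn does line for line.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy "images": 200 samples with 16 features on different scales.
X = rng.normal(size=(200, 16)) * rng.uniform(0.5, 3.0, size=16) + 5.0
k = 4  # latent size for this toy example

# Step 1: standardize each feature (the job of the scaler).
mean, std = X.mean(axis=0), X.std(axis=0)
X_scaled = (X - mean) / std

# Step 2: PCA via SVD; the rows of `components` are orthonormal directions.
_, _, Vt = np.linalg.svd(X_scaled, full_matrices=False)
components = Vt[:k]

latents = X_scaled @ components.T              # encode: project onto the top k directions
X_back = (latents @ components) * std + mean   # decode: undo the projection, then the scaling
```

Decoding is just a matrix multiply followed by un-scaling. Keeping all 16 directions would give the original data back exactly, and keeping fewer throws away only the directions with the least variance.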
And here is the result of our reconstruction:
As we can see, the results are surprisingly good even though we discard roughly 80% of the values in each image, going from 784 down to 144 (though clearly not 80% of the information)!
Now that we can go “there and back again”, let’s explore how noise behaves in this latent space.
Adding Noise to Our Image Encodings
Recall that the next step in our process is to add noise that we later want to predict. This is very straightforward in code:
rng = np.random.default_rng(1337)
noise = rng.normal(loc=0.0, scale=1.0, size=X_encode.shape)
X_noisy_encode = X_encode + noise
What’s very interesting is the way this noise in the latent/embedding space impacts the reconstruction. Normally we think of adding noise to an image as adding “snow” but we’ll see that adding this noise to our latent space causes the reconstruction to be more robust than you might expect:
In the above image we just added noise generated with a standard normal. Let’s increase the noise and see what that looks like:
What’s fascinating is that the latent space seems to preserve some information about “digitness” so to speak. It looks more like we’ve smeared these images with our finger than just added noise. To really see how powerful the latent space is itself, let’s look at pure noise reconstructed:
Here we see that embedding these images doesn’t simply save space, the embedding space captures a lot of information about what a digit is. Again, this last set of images is literally just noise, there is no image behind these. Yet, even with this simple, linear encoding, we can see that these do look vaguely like digits.
What this implies is that in this learned latent space, all points are much closer to real images than any arbitrary point in the original 784 value space. Consider that diffusion models make use of a much more powerful, non-linear encoder and it’s suddenly less shocking (though still remarkable) that models like Stable Diffusion can generate such convincing results. Even noise in the latent space resembles known images.
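Here’s a small sketch of the linear algebra behind that observation, using toy data instead of MNIST. The data, the sizes (50 dimensions, 3 latent values) and the helper `distance_from_subspace` are all assumptions made up for this example. It only shows that anything we decode from the latent space lands inside the subspace the data lives in, not that the result looks like a digit.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "images": 50-dimensional points that really live near a 3-dimensional subspace.
basis = rng.normal(size=(3, 50))
data = rng.normal(size=(500, 3)) @ basis + 0.01 * rng.normal(size=(500, 50))

mean = data.mean(axis=0)
_, _, Vt = np.linalg.svd(data - mean, full_matrices=False)
components = Vt[:3]  # linear encoder/decoder with 3 latent values

def distance_from_subspace(point):
    """Distance from a point to the set of points the decoder can produce."""
    centered = point - mean
    return np.linalg.norm(centered - (centered @ components.T) @ components)

decoded_noise = rng.normal(size=3) @ components + mean  # decode pure latent noise
raw_noise = rng.normal(size=50)                         # noise in the original space

print(distance_from_subspace(decoded_noise))  # essentially zero
print(distance_from_subspace(raw_noise))      # far from zero
```

Noise drawn in the latent space can only ever be decoded into a combination of directions learned from the data, while noise drawn in the original space has no such constraint.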
Denoising with Linear Regression (and interaction terms!)
We’re very close to being done with our model! The only big piece of our model left is to predict the noise we added to our image.
However since we just have a linear model we are going to do a tiny bit of feature engineering before we pass it in to a regular ol’ regression model. Our simple model has no chance of learning anything about the noise specific to a text embedding without us helping it make a few connections. Rather than just concatenating our image embedding and our text embedding, we’re going to also add interaction terms.
Our interaction terms are simply going to be multiplying each value in the image embedding by each value in the text embedding to create some new features. Here’s a visualization of how our final feature vector is built:
Here is the code for building these interaction terms:
def create_interaction_terms(text_embeddings, image_embeddings):
    interactions = []
    for i in range(text_embeddings.shape[1]):
        interactions.append((image_embeddings * np.array([text_embeddings[:,i]]).T))
    return np.concatenate(interactions, axis=1)
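A tiny worked example makes it easier to see what these terms are doing. The function is repeated here so the snippet runs on its own (indexing with `[:, [i]]` is just a shorter way of getting the same column vector), and the two small arrays are made-up inputs.

```python
import numpy as np

def create_interaction_terms(text_embeddings, image_embeddings):
    interactions = []
    for i in range(text_embeddings.shape[1]):
        # Multiply every image value by the i-th text value of the same sample.
        interactions.append(image_embeddings * text_embeddings[:, [i]])
    return np.concatenate(interactions, axis=1)

text = np.array([[1.0, 0.0],    # first sample: first text feature active
                 [0.0, 1.0]])   # second sample: second text feature active
image = np.array([[2.0, 3.0, 4.0],
                  [5.0, 6.0, 7.0]])

terms = create_interaction_terms(text, image)
print(terms)
# [[2. 3. 4. 0. 0. 0.]
#  [0. 0. 0. 5. 6. 7.]]
```

Because the text embedding is one-hot, each sample gets a copy of its image embedding in the slot belonging to its digit and zeros everywhere else. That effectively lets the regression learn a separate set of weights per digit.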
And we can put all these together to build our final features and keep track of the noise so we can use it as our target:
def create_features_noise(text_embeddings, image_embeddings=None, std=1.0, seed=1337):
    rng = np.random.default_rng(seed)
    noise = rng.normal(loc=0.0, scale=std,
                       size=(text_embeddings.shape[0], LATENT_SIZE*LATENT_SIZE))
    if image_embeddings is not None:
        noised_embeddings = image_embeddings + noise
    else:
        noised_embeddings = noise
    interaction_terms = create_interaction_terms(text_embeddings, noised_embeddings)
    # we need the noise back directly since it will be used as a target in training
    features = np.concatenate([noised_embeddings, text_embeddings, interaction_terms], axis=1)
    return features, noise
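It’s worth a quick count of how wide this feature vector ends up. The arithmetic below assumes the 12x12 latent from earlier and a 9-value text embedding (ten digits one-hot encoded with the first category dropped); the variable names are just for this illustration.

```python
LATENT_SIZE = 12
n_latent = LATENT_SIZE * LATENT_SIZE   # 144 noisy image-embedding values
n_text = 9                             # 10 digits, one-hot encoded with the first dropped
n_interactions = n_latent * n_text     # one product per (image value, text value) pair
n_features = n_latent + n_text + n_interactions
print(n_features)  # 1449

# A linear regression predicting all 144 noise values has one weight per
# feature plus an intercept for every output.
n_parameters = (n_features + 1) * n_latent
print(n_parameters)  # 208800
```

So under those assumptions the whole denoiser is a couple hundred thousand numbers, which is tiny next to a real diffusion model.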
And now we can finally build our denoiser!
The final model - Linear regression!
Much, much, much simpler than a UNET, here is our final model in code:
from sklearn.linear_model import LinearRegression

X_train, y_noise = create_features_noise(label_embeddings, X_encode)
model = LinearRegression().fit(X=X_train, y=y_noise)
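If you’re wondering what a linear model can even learn about noise, a one-dimensional toy version gives some intuition. This sketch is not the Linear Diffusion model: it has a single latent value, no text, and an assumed spread for the clean signal (`signal_std`). It fits the same kind of least-squares regression with NumPy and compares the slope to what the textbook formula predicts.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
signal_std = 2.0  # assumed spread of the clean latent value

clean = rng.normal(scale=signal_std, size=n)
noise = rng.normal(size=n)  # standard normal noise, as in the article
noisy = clean + noise

# Least-squares fit of noise ~ a * noisy + b, the same problem a linear regression solves.
design = np.column_stack([noisy, np.ones(n)])
(a, b), *_ = np.linalg.lstsq(design, noise, rcond=None)

expected_a = 1.0 / (signal_std**2 + 1.0)  # Cov(noise, noisy) / Var(noisy) = 0.2
denoised = noisy - (a * noisy + b)        # subtract the estimated noise
```

In this toy case the model learns to blame a fixed fraction of each noisy value on noise, so denoising just shrinks values toward the mean. The text embedding and interaction terms in the real model let that estimate differ by digit.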
That’s it, that’s all there is to our very simple Linear Diffusion model!
Generation and Results
If you look carefully at the definition of create_features_noise you’ll notice we don’t have to pass in image embeddings. If we don’t, the function will just use the noise as the embedding. In the final version of Linear Diffusion I’ve refactored all of the above code into a LinearDiffusion class. Here is the predict method which does the actual generation:
def predict(self, X, seed=1337):
    labels = np.array(X).reshape(-1, 1)
    label_embeddings = self.text_encoder.transform(labels)
    X_test, noise_test = self._create_features_noise(label_embeddings, seed=seed)
    est_noise = self.model.predict(X_test)
    denoised = noise_test - est_noise
    est_imgs = self.image_encoder.inverse_transform(denoised).reshape(
        label_embeddings.shape[0], self.image_size, self.image_size)
    return est_imgs
Here is an example of training and running Linear Diffusion end-to-end:
import mnist
import numpy as np
from lineardiffusion import LinearDiffusion

# setup the training data
train_imgs = mnist.train_images()
test_imgs = mnist.test_images()
all_imgs = np.concatenate([train_imgs, test_imgs])
all_labels = [str(val) for val in np.concatenate([mnist.train_labels(), mnist.test_labels()])]

ld = LinearDiffusion()
ld.fit(all_labels, all_imgs)
Finally we go ahead and create some examples:
from itertools import chain

test_labels = list(chain.from_iterable([[str(i)]*5 for i in range(10)]))
# here's our prediction!
test_images = ld.predict(test_labels, seed=137)
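In case the one-liner that builds the prompts is a bit dense, here’s what it produces on its own: five copies of each digit string, in order.

```python
from itertools import chain

# Five prompts for each digit: "0" five times, then "1" five times, and so on.
test_labels = list(chain.from_iterable([[str(i)] * 5 for i in range(10)]))
print(len(test_labels))  # 50
print(test_labels[:7])   # ['0', '0', '0', '0', '0', '1', '1']
```

Each of the 50 prompts gets its own random latent noise, which is why the five samples for a given digit come out different from one another.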
And the results!
While these images are still quite imperfect, considering the remarkable task we’ve asked of these very simple models, I think these results are downright impressive. What we see here is that we have been able to map text embeddings to image embeddings in a generative manner. I also suspect that curious minds could find easy improvements in the model that would take the performance to the next level. It also achieves all of this with a few megabytes of data in under a minute!
Conclusion
In an age when it seems like all interesting work in AI is done by large corporations with mindbogglingly expensive compute resources and nearly all the training data that exists, it’s easy for the individual hacker to get discouraged. For me, this is part of the appeal of simple, linear models. They allow us to focus on sketching out ideas, and are more powerful than most people realize.
It’s tremendous fun to play with models like Stable Diffusion and Dall-E 2, and easier than ever. Yet as individual programmers and data scientists we shouldn’t feel confined to only play with someone else’s models. There is still much to be discovered and explored by combining the engineering oriented thinking around these impressive and complex architectures with even the most basic of modeling tools. Stretch your model thinking, try something ridiculous and the results may surprise you!
I also want to give a shout out to Jeremy Howard and the team at Fast.ai whose Practical Deep Learning for Coders Part 2 was instrumental in helping me to understand diffusion models better and made Linear Diffusion possible. I highly recommend anyone looking to understand diffusion models in more detail (and there’s a lot more detail!) check out that course.
Want To Learn More About Stable Diffusion?
Then check out A Damn Fine Stable Diffusion Book being released by Manning! The preview version of the book is available today!