In this blog post, we will demonstrate how to use the DiffEdit technique described in this paper to modify just one part of an existing image using simple text prompts and a diffusion model. DiffEdit makes use of a diffusion model's ability to predict where the noise is in an image, which is normally how these models generate images from text prompts. In this blog post we assume you have basic knowledge of Stable Diffusion and how it works, as this is the diffusion model we will be using throughout our examples.

This blog post is based on our notebook on Kaggle. In this post we aim to provide a distilled version of the full code, explaining the important steps while skipping some of the more basic setup code. Check out the notebook if you want to see the full code or are interested in running it or adapting it yourself.

We will work through an example where we use the DiffEdit technique to modify an image of a horse into an image of a zebra. This will be our final result:

[Image: the original horse photo next to the final DiffEdit zebra result]

While many diffusion models can be used in an image-to-image (img2img) pipeline, that approach is fairly limited compared to the DiffEdit technique: img2img changes all parts of the image, and you cannot limit the modifications to a specific area. In our image above you can see that the background is unchanged between our original image and the generated zebra image.

The DiffEdit paper uses the noise predictions that a diffusion model provides to automatically limit edits to certain parts of an image. It does this in three main steps, which are summarised in the pipeline chart below:

[Diagram: the three main steps of the DiffEdit pipeline]

In the sections below we will work through these steps. First we will generate the mask and demonstrate a few techniques for improving it, then noise our original image, before finally using the mask to edit only one part of our original image.

Creating the image mask

The first step of the DiffEdit technique is creating the mask by utilising the difference in noise prediction between two text prompts. The first prompt is the part of the image that you want to replace, and the second prompt should be what you want to replace it with, so in our examples we’re going to set our original prompt to be "horse" and the new prompt to be "zebra".

Let’s take a look at our original image.

[Image: the original horse image]

To get the difference in noise prediction between the two prompts, we add 50% noise to our original image, then give our diffusion model the two prompts and ask it to predict the noise for both "horse" and "zebra". We can then take the difference between the two noise predictions to create our mask.

This works because the original image is already an image of a horse, so our "horse" prompt should produce a noise prediction which is fairly evenly distributed. For the "zebra" prompt however, while the background is plausible given the prompt, the horse in the foreground isn't very zebra-like at all, so our model will predict that there must be additional noise in this area. This means that if we take the difference between the two predictions (averaged over 10 runs), we should get a pretty good outline of a horse.

Our diffusion model is unable to work with images directly, so we first have to convert the image into latents using its variational autoencoder (VAE).

original_image_latents = encode(original_image)
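The encode helper is defined in the full notebook; a minimal sketch of what it might look like, assuming the Stable Diffusion VAE has been loaded into a vae variable via diffusers, is shown below (the exact preprocessing is an assumption on our part):

import torch
from torchvision import transforms as tfms

def encode(image):
    # scale pixel values from [0, 1] to [-1, 1], which is the range the VAE expects
    image_tensor = tfms.ToTensor()(image).unsqueeze(0) * 2 - 1
    with torch.no_grad():
        latent_dist = vae.encode(image_tensor.to("cuda").half()).latent_dist
    # sample from the latent distribution and apply the Stable Diffusion scaling factor
    return 0.18215 * latent_dist.sample()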

We can investigate the shape of our image latents:

> original_image_latents.shape
torch.Size([1, 4, 64, 64])

We can see that the image has been converted into a 1x4x64x64 tensor, which is the latent format that our particular diffusion model works with. The images below show what those latents look like across the 4 channels:

[Image: the four latent channels of the encoded original image]

NOTE: These 4 channels do not directly map to RGBA or any other format. The contents of the latents are completely determined by the VAE, and other VAE models trained in different ways would produce different latents from our original image.

Prompt to Text Embeddings

We have now converted our original image into the correct format for our model, and we also need to do this for the text prompts we will be providing.

To do this we first use a tokenizer to map the text into numerical token IDs, and then we use a text encoder to convert these into tensors in a format that the model can work with. We have wrapped this behaviour in a handy function so that we can reuse the logic later for other prompts.

def create_text_embeddings(prompts, batch_size=1):
    text_embeddings = []
    for prompt in prompts:
        # tokenise the prompt and encode it into an embedding tensor
        text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings.append(text_encoder(text_input.input_ids.to("cuda"))[0].half())

    # create a matching unconditional ("") embedding for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = text_encoder(uncond_input.input_ids.to("cuda"))[0].half()

    # stack the unconditional embedding first, followed by one embedding per prompt
    return torch.cat([uncond_embeddings, *text_embeddings])

To get an idea of what the tokenised text looks like:

> tokenizer(["horse"], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

 {
     'input_ids': tensor([[49406,  4558, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
           49407, 49407, 49407, 49407, 49407, 49407, 49407]]),
      'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0]])
 }

We can see above that our prompt of "horse" has been converted into different tokens:

  • 49406 is the start-of-text token
  • 4558 is the token for "horse"
  • 49407 is the end-of-text token, which is also used for padding

The tensor is padded to a particular length that the model expects, but an attention mask is also passed in which tells the model which tokens it should pay attention to: a 1 means pay attention and a 0 means ignore. In this case it is only the first three tokens.
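If you want to double-check what the individual token IDs map to, the tokenizer can also decode them back into text. The exact output depends on the tokenizer, but for the CLIP tokenizer used by Stable Diffusion it looks something like this:

> tokenizer.decode([49406, 4558, 49407])
'<|startoftext|>horse <|endoftext|>'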

original_prompt = "horse"
new_prompt = "zebra"

prompts = [original_prompt, new_prompt]
combined_embeddings = create_text_embeddings(prompts)
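As a quick sanity check, the result stacks the unconditional embedding and our two prompt embeddings. Assuming the standard Stable Diffusion v1 text encoder (77 token positions with a hidden size of 768), the shape should be:

> combined_embeddings.shape
torch.Size([3, 77, 768])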

Noising the original image

To add the correct amount of noise to our original image, we will use the scheduler. The scheduler typically manages noise throughout the image creation process for diffusion models, but it can also be used to manually add a specific level of noise to an image.

First we need to set up the scheduler, and provide it with the configuration values that our diffusion model expects. We have used a DDIM scheduler, which is the type of scheduler that is recommended by the DiffEdit paper as it allows for more consistent results between the original image and the edited image.

We’ll define this setup as a function as we’ll be using it again later when we generate the DiffEdit image. The argument num_inference_steps is the total number of steps that our model would take to fully denoise an image. With the version of Stable Diffusion and the scheduler we have chosen, 50 steps usually works well.

def init_scheduler(num_inference_steps):  
    scheduler = DDIMScheduler(**scheduler_config)
    scheduler.set_timesteps(num_inference_steps)
    return scheduler
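The scheduler_config values come from the model being used. For Stable Diffusion v1.x, a DDIM configuration along these lines is typical; treat the values below (and the guidance_scale we use throughout) as an illustration rather than the exact configuration from our notebook:

scheduler_config = {
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "num_train_timesteps": 1000,
    "clip_sample": False,
    "set_alpha_to_one": False,
    "steps_offset": 1,
}

num_inference_steps = 50
scheduler = init_scheduler(num_inference_steps)

# classifier-free guidance strength; 7.5 is a common default for Stable Diffusion
guidance_scale = 7.5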

Next, we will define a generic function which creates noise and adds it to the latents we pass in. It also takes in a start step for the scheduler which allows us to control the amount of noise added to the image.

def create_noise(noise_scheduler, initial_latents, height=512, width=512, batch_size=1, start_step=10, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    initial_noise = torch.randn((batch_size, unet.config.in_channels, height // 8, width // 8)).to("cuda").half()
    
    if start_step > 0:
        return noise_scheduler.add_noise(
            initial_latents,
            initial_noise, 
            timesteps=torch.tensor([noise_scheduler.timesteps[start_step]])
        )
    return initial_noise * noise_scheduler.init_noise_sigma

We will pass the latents from our encoded original image into our new function, and we want to add 50% noise as suggested in the DiffEdit paper. Our function doesn't take an amount of noise directly, but instead lets us specify which step the scheduler should use to noise the image; starting halfway through the total number of steps adds 50% noise.

mask_noise_step = num_inference_steps//2
noisy_latents = create_noise(scheduler, original_image_latents, start_step=mask_noise_step)

We can also visualise our noisy latents to see what they now look like with the noise added:

[Image: the four latent channels after the noise has been added]

It’s also possible to decode the noisy latents using our VAE. There’s a hint of the original image there but certainly a lot of noise added too:

[Image: the decoded noisy latents]
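Like encode, the decode helper comes from the full notebook; a rough counterpart sketch using the same vae might look like this:

def decode(latents):
    # undo the scaling factor and run the VAE decoder
    with torch.no_grad():
        decoded = vae.decode(latents / 0.18215).sample
    # rescale from [-1, 1] back to [0, 1] and move the channels last for plotting
    image = (decoded / 2 + 0.5).clamp(0, 1)
    return image.cpu().permute(0, 2, 3, 1).float().numpy()[0]

decoded_noisy_image = decode(noisy_latents)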

Noise Prediction Code

Now that we have created the noisy latents, we need to use our model to predict where the noise is in the image, given the two prompts that we're working with. The noise prediction code that we will use here is very similar to the code you would use for image generation. The main differences are that we start with our original image latents at 50% noise, and only do a single step of noise prediction rather than the 50 steps we would run when generating an image.

def predict_noise(scheduler, latents, combined_embeddings, guidance_scale=7.5, timestep=0, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    input = torch.cat([latents] * combined_embeddings.shape[0])
    input = scheduler.scale_model_input(input, timestep)

    # predict the noise residual
    with torch.no_grad():
        pred = unet(input, timestep, encoder_hidden_states=combined_embeddings).sample

    # perform guidance
    pred_uncond, pred_text_original, pred_text_new = pred.chunk(3)
    
    pred_original = pred_uncond + guidance_scale * (pred_text_original - pred_uncond)
    pred_new = pred_uncond + guidance_scale * (pred_text_new - pred_uncond)

    return pred_original, pred_new

The paper states that repeating the noise prediction 10 times was enough to generate a good mask, so we will use that value here. On each loop we set a different seed so that the added noise differs, take the difference between the noise predictions for our two prompts, and accumulate this difference in the variable all_preds_base. Once the loops are complete, we divide all_preds_base by the number of loops to get the average difference over our 10 predictions.

n = 10

with torch.no_grad():
    all_preds_base = torch.zeros_like(original_image_latents)
    original_preds_base = torch.zeros_like(original_image_latents)
    new_preds_base = torch.zeros_like(original_image_latents)
    for i in range(n):
        # use a different seed on each loop so the added noise differs each time
        seed = i * 3
        noise = create_noise(scheduler, original_image_latents, start_step=mask_noise_step, seed=seed)
        original_preds, new_preds = predict_noise(scheduler,
                                                  noise,
                                                  combined_embeddings,
                                                  guidance_scale,
                                                  scheduler.timesteps[mask_noise_step],
                                                  seed=seed)
        original_preds_base += original_preds
        new_preds_base += new_preds
        all_preds_base += (new_preds - original_preds).abs()
    # average the accumulated predictions over the n loops
    all_preds_base /= n
    original_preds_base /= n
    new_preds_base /= n

You can see below that the horse prompt gives a fairly even noise prediction over the ten loops, whereas the zebra prompt highlights the horse specifically. Since the noise predictions for the two prompts are also very similar for the background, this leaves us with just an outline of the horse when the two are subtracted:

[Image: averaged noise predictions for the "horse" and "zebra" prompts, and the difference between them]

> all_preds_base.shape
torch.Size([1, 4, 64, 64])

all_preds_base is currently still the same shape as the initial latents, so we need to combine those four channels into one, which we do by taking the mean across the channels. After that, we want to normalise the values to between 0 and 1, so that the difference values are easier to work with when creating the mask.

all_preds = all_preds_base.mean(axis=1)

all_preds = all_preds - all_preds.min()
all_preds = all_preds / all_preds.max()
> all_preds.min(), all_preds.max()
(tensor(0., device='cuda:0', dtype=torch.float16),
    tensor(1., device='cuda:0', dtype=torch.float16))
all_preds = all_preds.squeeze().detach().cpu().numpy()

Now that we’ve done some processing, the mask looks like the image below:

[Image: the normalised difference map]

That’s a pretty good start; however, what we really want is a mask which is either present or not at each pixel. To achieve this, we will set all values under a threshold to 0 and all values above it to 1, giving us binary values to work with for the mask. The DiffEdit paper suggests setting this threshold at 0.5, however in our testing we found a lower value was more effective. This may depend on the similarity of your two prompts.

We can try a few different threshold values and visualise the resulting mask. This will allow us to choose a value that works well for our image:
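Here is a sketch of how that comparison could be produced; the set of candidate thresholds is our own choice rather than something from the paper:

import matplotlib.pyplot as plt

thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
fig, axes = plt.subplots(1, len(thresholds), figsize=(15, 3))
for ax, threshold in zip(axes, thresholds):
    # binarise the difference map at this threshold and show it as a greyscale image
    ax.imshow(all_preds > threshold, cmap="gray")
    ax.set_title(f"threshold = {threshold}")
    ax.axis("off")
plt.show()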

[Image: binary masks produced by a range of threshold values]

0.1 looks like the best threshold of the options above, as the higher values no longer cover the whole horse. The shape isn’t perfect however, so we can do some more processing of the mask to improve it.

# set the best threshold value
initial_mask = (all_preds > 0.1).astype(float)

One technique to help with the coverage of the mask is to blur it a small amount. This can help to remove small holes and give a smoother outline on the mask. We again visualise a few different values to see what works best in this case:

[Image: the mask blurred with a range of kernel sizes]

A 7px blur kernel looks like the best option here: it gives a smoother outline around the horse without losing the overall shape or leaving parts of the horse uncovered, so we’ll use that value:

mask_np = (cv.GaussianBlur(np.array(initial_mask, dtype=np.uint8), (7, 7), 0.) > 0.1).astype(float)

If there are holes in the mask, there are other techniques that can help, such as binary_fill_holes from the scipy.ndimage package (see the sketch after the image below). We’ll skip this since our example doesn’t have any holes, but it can be another option to improve the mask further. For now, we’ll stick with the mask we’ve created, which now looks like this:

[Image: the final processed mask]
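For reference, the hole-filling approach mentioned above only takes a couple of lines if your own mask does have gaps; this is a hedged sketch rather than something we apply to our horse mask, and filled_mask_np is just an illustrative name:

from scipy.ndimage import binary_fill_holes

# fill any fully-enclosed background regions inside the masked area
filled_mask_np = binary_fill_holes(mask_np.astype(bool)).astype(float)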

Applying DiffEdit to our original image

Now that we have a mask, we can use it to ensure that the generation of our zebra only affects the area under the mask. We do this using a standard img2img pipeline for a diffusion model, but with one main change: at each step of the denoising process, instead of keeping what our diffusion model generates in the areas outside the mask, we replace those areas with a version of our original image noised to the correct level for that step. Because this happens at every step, the final generated image within the mask integrates seamlessly with the original image outside of it, as we effectively trick the model at each step into believing that it generated the background of the original image. We’ll explain exactly how we do this in more detail later.

For now, we want to set up our img2img pipeline in the standard way, with one exception: we want to create two schedulers, one which handles the denoising of the image we are generating as normal, and another which allows us to correctly noise our original image to the same level, which we will then use within the denoising step. We will call these denoising_scheduler and original_image_scheduler respectively.

Since we will be using an img2img pipeline, we will be passing in our original image latents rather than starting with pure noise. This is because we still want our zebra to have a similar pose to the horse in our original image. We add roughly 75% noise to our original image latents, by noising the image with a scheduler start_step which is already 25% of the way through the total number of steps, effectively skipping the first 25% of the denoising process. This allows some of the original horse image to influence our zebra image generation. We save these new latents in the variable noisy_latents and use them as the starting point for the image generation of our zebra.
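The setup code for this stage lives in the full notebook; a condensed sketch of the pieces the denoising loop below relies on might look like the following (the seed and the exact variable wiring are our assumptions):

num_inference_steps = 50
start_step = 12  # skip ~25% of the steps, so ~75% noise is added to the original image

# two schedulers with identical configuration: one drives the denoising loop,
# the other is used to re-noise the original image at each step
denoising_scheduler = init_scheduler(num_inference_steps)
original_image_scheduler = init_scheduler(num_inference_steps)

# unconditional embedding plus our new "zebra" prompt
text_embeddings = create_text_embeddings([new_prompt])

# move the binary mask onto the GPU so it broadcasts against the latents
mask = torch.tensor(mask_np, device="cuda", dtype=torch.float16)

# start the generation from a noised version of the original image
noisy_latents = create_noise(denoising_scheduler, original_image_latents, start_step=start_step, seed=42)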

The next step is to perform the denoising loop, starting with our noisy latents above and then progressively denoising the image using the prompt "zebra", similar to a standard img2img pipeline. However, at each step of our denoising loop we must ensure that we only change the areas under the mask. We do this by calculating xt, which is a version of our original image latents with the correct amount of noise for the next step in the denoising process, and yt, which is the slightly denoised latents that the scheduler has created based on our U-Net’s noise prediction and that would typically be the input to the next step of the denoising loop in standard image generation.

Next we combine the two sets of latents, to ensure that yt (the generated latents) is only applied within the masked area, and xt (our original image latents) makes up the rest of the image. We do this using an equation from the DiffEdit paper, and you can see that implemented in the code on the following line:

noisy_latents = mask*yt+((1-mask)*xt)

Our mask is a tensor whose values are either 1 or 0, with 1 inside the masked area and 0 everywhere else. The first part of this code, mask*yt, effectively multiplies the generated latents by 0 outside of the masked area, stopping them from having any influence on the image there. The second part, ((1-mask)*xt), does the opposite: 1-mask sets all of the non-masked areas to 1 and the masked area to 0, so when we multiply xt by this we reduce the influence of the original-image latents within the masked area to 0. Adding the two parts together produces the input latents for the next step, where xt makes up the non-masked area and yt makes up the masked area. These new noisy_latents then become the input to the next step of the denoising loop.
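To make the equation concrete, here is a tiny toy example on made-up 2x2 "latents", with values chosen purely for illustration:

toy_mask = torch.tensor([[1., 0.],
                         [0., 1.]])
toy_yt = torch.full((2, 2), 9.)  # stand-in for the generated latents
toy_xt = torch.full((2, 2), 2.)  # stand-in for the noised original-image latents

print(toy_mask * toy_yt + (1 - toy_mask) * toy_xt)
# tensor([[9., 2.],
#         [2., 9.]])  -> generated values inside the mask, original values outside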

You can see how this all fits together within our denoising loop:

# denoising loop
for i, t in enumerate(tqdm(denoising_scheduler.timesteps[start_step:])):
    current_step = i + start_step
    next_step = current_step + 1

    input = torch.cat([noisy_latents] * 2)
    scaled_input = denoising_scheduler.scale_model_input(input, t)

    # predict the noise residual
    with torch.no_grad(): 
        pred = unet(scaled_input, t, encoder_hidden_states=text_embeddings).sample

    # perform guidance
    pred_uncond, pred_text = pred.chunk(2)
    pred_nonscaled = pred_uncond + guidance_scale * (pred_text - pred_uncond)
    pred = pred_nonscaled

    # compute the "previous" noisy sample
    if (next_step < num_inference_steps):
        xt = create_noise(original_image_scheduler, original_image_latents, start_step=next_step)
    else:
        xt = original_image_latents

    yt = denoising_scheduler.step(pred, t, noisy_latents).prev_sample

    noisy_latents = mask*yt+((1-mask)*xt)

    if callback:
        callback(noisy_latents)
final_latents = noisy_latents

Once the denoising loop is complete, we can decode the final_latents to see what our generated "zebra" DiffEdit image looks like:

[Image: the final DiffEdit zebra image]

And there we have it! Our original horse has been completely replaced with our generated zebra, and through the power of the DiffEdit process, the background outside of the masked area has been left untouched.

We also added a callback so that we could see the denoising process as it progressed. We started at step 12 since we wanted the original image to still guide the pose of our zebra in our output image.
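The callback itself can be very simple; a sketch along these lines (defined before the loop runs, and reusing the decode helper from earlier) is enough to capture the progression:

progress_images = []

def callback(latents):
    # decode each intermediate set of latents so we can plot the progression later
    progress_images.append(decode(latents))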

[Image: intermediate images from the denoising process]

Limitations

Something else worth considering is the limitations of the DiffEdit mask creation technique. If there are big differences between the original image caption and the query string, the generated mask is not as accurate, which has the knock-on effect of leading to poor final image generation. Take our example below: we’ll use an image of a person holding a bowl, with "bowl of fruits" as the caption and "sports cars" as the query. As you can see, the mask includes random areas of the background due to the lack of similarity between the original image and the intended query, and the result is a rather odd final image.

caption = "bowl of fruits"
query = "sports cars"

diffedit(fruit_bowl_image, caption, query)

[Image: the mask generated for the "sports cars" query]

[Image: the odd final image generated for the "sports cars" query]

Conclusion

When working on this project we set out to implement the techniques from the DiffEdit paper in our own custom pipeline. As you can see in the examples above, we were able to achieve this, with accurate mask generation and decent final image replacement. DiffEdit is a very powerful technique, and it was surprisingly easy to adapt an img2img pipeline to work with it. The ability to pass in just two text prompts and have an accurate mask generated as a result was truly impressive.

There is still more for us to explore with this technique, from tweaking the various parameters further to tune both the mask creation and the image generation, to further understanding what prompts give the best results. It would also be interesting to adapt this to work with Stable Diffusion 2 or SDXL and see how that affects both the mask creation and final image generation.

An interesting use case for a DiffEdit-like technique is Google’s Magic Eraser, which automatically creates a mask highlighting people in the background of an image and then replaces them with a continuation of the scenery rather than another object. It could be a fun extension task to adapt this pipeline to take in an image and automatically remove people from the background.

Thanks for reading! If you liked this blog post and want to see the full notebook we used to write the code and generate all of the images (along with some bonus content), then take a look at it on Kaggle.