Walkthrough: Diffusion Models Class - Unit 1

A walkthrough of HuggingFace’s Diffusion models class - unit 1



Nov 2022

These are supplementary notes for the HuggingFace diffusion models class unit 1. The video below is a walkthrough of unit 1 plus these notes.

What does a diffusion model do?

Diagram of diffusion steps

Step 2: Preparing the training dataset

Why do we have to do these data augmentations and what are they doing?

from torchvision import transforms

image_size = 32  # target height/width in pixels (example value)

# Define data augmentations
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # Resize
    transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
    transforms.ToTensor(),  # Convert to tensor (0, 1)
    transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
])
  • images are usually stored as 8-bit integer pixels between 0 and 255 for each of the RGB channels, but neural networks typically train better on floating point values in a small range centered around zero
  • Resize: resizes every image to image_size x image_size, so all images have the same shape (needed to stack them into a batch) and the network trains faster on the smaller inputs
  • RandomHorizontalFlip: flips each image left-to-right with probability 0.5, so the model sees more varied training data without us having to collect more images
  • ToTensor: converts the input image (a PIL image or an H x W x C uint8 NumPy array) to a torch.FloatTensor of shape (C x H x W) with values scaled to [0.0, 1.0]
    • we need to convert it to a tensor so that it’s in a data format PyTorch can use. C is channel, H is height, W is width. We have 3 channels, one for each of R, G and B.
  • Normalize: computes (x - mean) / std with mean 0.5 and standard deviation 0.5, which maps values from [0, 1] to [-1, 1]; neural networks work better with inputs centered around zero
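The arithmetic that ToTensor and Normalize perform on a uint8 image, plus the flip, can be reproduced in a few lines of NumPy. This is a minimal sketch to make the numbers concrete, not torchvision's implementation: to_tensor_and_normalize and random_horizontal_flip are hypothetical helpers that mimic the two transforms, and Resize is left out.

```python
import numpy as np

def to_tensor_and_normalize(img_hwc_uint8, mean=0.5, std=0.5):
    """Mimic ToTensor() followed by Normalize([mean], [std]) (hypothetical helper)."""
    # ToTensor: H x W x C uint8 in [0, 255] -> C x H x W float32 in [0.0, 1.0]
    x = img_hwc_uint8.astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))
    # Normalize: (x - mean) / std, so [0, 1] maps to [-1, 1] when mean = std = 0.5
    return (x - mean) / std

def random_horizontal_flip(img_hwc, p=0.5, rng=None):
    """Mimic RandomHorizontalFlip: flip left-right with probability p (hypothetical helper)."""
    if rng is None:
        rng = np.random.default_rng()
    return img_hwc[:, ::-1, :] if rng.random() < p else img_hwc

# A tiny 2x2 RGB "image" with black, white, grey and pure-red pixels
img = np.array([[[0, 0, 0], [255, 255, 255]],
                [[128, 128, 128], [255, 0, 0]]], dtype=np.uint8)
out = to_tensor_and_normalize(random_horizontal_flip(img))
print(out.shape)             # (3, 2, 2): channels first
print(out.min(), out.max())  # -1.0 1.0
```

Black pixels (0) end up at -1.0 and white pixels (255) at 1.0, whether or not the image was flipped.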

Step 3: Define the scheduler

What does this equation mean?

\(q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})\)

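  • remember the input to our model is a ‘noisy’ image, so we have to add noise to our input images at each timestep
  • the first expression says the noisy image \(\mathbf{x}_t\) at timestep t is drawn from a normal distribution ( \(\mathcal{N}\) ) with mean \(\sqrt{1 - \beta_t}\,\mathbf{x}_{t-1}\) (a slightly scaled-down copy of the previous image) and variance \(\beta_t\mathbf{I}\), so \(\beta_t\) controls how much noise is added at step t
  • the second expression says the forward process is a Markov chain: the probability of the whole sequence \(\mathbf{x}_1, \dots, \mathbf{x}_T\) given the clean image \(\mathbf{x}_0\) is the product of these single-step transitions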
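Why scale the previous image by \(\sqrt{1-\beta_t}\) instead of keeping \(\mathbf{x}_{t-1}\) as the mean? A short worked derivation (my own addition, not from the course) using the law of total variance, for a single pixel:

\(\mathrm{Var}[x_t] = \mathbb{E}\big[\mathrm{Var}[x_t \vert x_{t-1}]\big] + \mathrm{Var}\big[\mathbb{E}[x_t \vert x_{t-1}]\big] = \beta_t + (1-\beta_t)\,\mathrm{Var}[x_{t-1}]\)

If \(\mathrm{Var}[x_{t-1}] = 1\), then \(\mathrm{Var}[x_t] = \beta_t + 1 - \beta_t = 1\): the shrinking of the previous image exactly offsets the added noise. If it is not 1, the gap shrinks each step, since \(\mathrm{Var}[x_t] - 1 = (1-\beta_t)(\mathrm{Var}[x_{t-1}] - 1)\). Without the scaling, the variance would instead grow to \(\mathrm{Var}[x_{t-1}] + \beta_t\) at every step.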

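Using the reparameterization trick, a sample from this distribution can be written as the previous image scaled by \(\sqrt{1-\beta_t}\) plus standard Gaussian noise scaled by \(\sqrt{\beta_t}\):

\(\mathbf{x}_t = \sqrt{1-\beta_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\)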
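A quick empirical check that the reparameterized update really samples from \(\mathcal{N}(\sqrt{1-\beta_t}\,\mathbf{x}_{t-1}, \beta_t\mathbf{I})\). This is a minimal NumPy sketch, not the course's code; q_sample_step is a hypothetical helper name, and the values beta_t = 0.1 and 0.8 are arbitrary examples.

```python
import numpy as np

def q_sample_step(x_prev, beta_t, rng):
    """Draw x_t ~ N(sqrt(1 - beta_t) * x_prev, beta_t * I) via the reparameterization trick."""
    eps = rng.standard_normal(np.shape(x_prev))  # eps ~ N(0, I)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps

rng = np.random.default_rng(0)
beta_t = 0.1
x_prev = np.full(100_000, 0.8)  # the same pixel value 0.8, repeated to collect many samples
x_t = q_sample_step(x_prev, beta_t, rng)

print(x_t.mean(), np.sqrt(1 - beta_t) * 0.8)  # empirical vs. analytic mean (about 0.759)
print(x_t.var(), beta_t)                      # empirical vs. analytic variance (0.1)
```

The point of the trick is that all randomness lives in eps, so x_t is a simple deterministic function of x_prev and beta_t.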

Here’s what that looks like in code for the forward diffusion (code thanks to J.M. Tomczak):

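import torch
import torch.nn as nn


class DDGM(nn.Module):
    def __init__(self, p_dnns, decoder_net, beta, T, D):
        super(DDGM, self).__init__()

        self.p_dnns = p_dnns  # a list of sequentials (not used in this excerpt)

        self.decoder_net = decoder_net

        # other params
        self.D = D

        self.T = T  # number of diffusion steps

        self.beta = torch.FloatTensor([beta])  # one constant noise level for every step

    @staticmethod
    def reparameterization(mu, log_var):
        # standard reparameterization trick: mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def reparameterization_gaussian_diffusion(self, x, i):
        # one forward step: sqrt(1 - beta) * x_{t-1} + sqrt(beta) * eps
        # (the step index i is unused because beta is constant)
        return torch.sqrt(1. - self.beta) * x + torch.sqrt(self.beta) * torch.randn_like(x)

    def forward(self, x, reduction='avg'):
        # =====
        # forward diffusion: x -> z_1 -> ... -> z_T
        zs = [self.reparameterization_gaussian_diffusion(x, 0)]

        for i in range(1, self.T):
            zs.append(self.reparameterization_gaussian_diffusion(zs[-1], i))

        # The full model continues with the reverse diffusion and the loss;
        # this excerpt just returns the noisy samples z_1, ..., z_T.
        return zs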
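  • take a look specifically at reparameterization_gaussian_diffusion, which implements the reparameterized update above: it scales the input by \(\sqrt{1-\beta}\) and adds Gaussian noise scaled by \(\sqrt{\beta}\). Note that it uses one constant β for every step (the index i is unused), whereas the equation allows a different \(\beta_t\) at each timestep t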
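The DDGM excerpt uses the same β at every step. The equation allows a per-step \(\beta_t\), and a common choice is the linear schedule from the DDPM paper: β rising from 1e-4 to 0.02 over T = 1000 steps. Below is a minimal NumPy sketch of the forward process under that assumed schedule; linear_beta_schedule and forward_diffusion are hypothetical helper names, and the flat 3x32x32 "image" is just a test input.

```python
import numpy as np

def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise levels beta_1, ..., beta_T (assumed DDPM-style schedule)."""
    return np.linspace(beta_start, beta_end, T)

def forward_diffusion(x0, betas, rng):
    """Apply x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps for t = 1..T; return [x_1, ..., x_T]."""
    xs, x = [], x0
    for beta_t in betas:
        eps = rng.standard_normal(x.shape)  # fresh noise at every step
        x = np.sqrt(1.0 - beta_t) * x + np.sqrt(beta_t) * eps
        xs.append(x)
    return xs

rng = np.random.default_rng(0)
x0 = np.ones((3, 32, 32))  # a flat white image (value 1.0 after normalizing to [-1, 1])
xs = forward_diffusion(x0, linear_beta_schedule(), rng)

# the mean decays toward 0 and the standard deviation rises toward 1 as t grows
for t in (1, 10, 100, 500, 1000):
    print(t, round(float(xs[t - 1].mean()), 3), round(float(xs[t - 1].std()), 3))
```

Small β early on keeps the first steps close to the clean image, while the larger β later pushes the image toward pure Gaussian noise by step T.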

Step 4: Define the Model

What is the U-Net model?

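  • originally developed for biomedical image segmentation
  • image segmentation is a task where you assign each pixel of an image to a class, so the output has the same height and width as the input
  • that same-size output is what makes it a good fit for diffusion: the model takes a noisy image and predicts an image-sized tensor (in this unit, the noise that was added)
  • the images below show an input image and its segmented output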

U-Net output

Here is what the code for a U-Net model looks like (thanks to ShawnBIT):

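import torch.nn as nn

# unetConv2, unetUp and init_weights are helpers defined elsewhere
# in the same repository (not shown here).


class UNet(nn.Module):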

    def __init__(self, in_channels=1, n_classes=2, feature_scale=2, is_deconv=True, is_batchnorm=True):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

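    def forward(self, inputs):
        # shape comments are C*H*W for a 512x512 input with feature_scale=4
        # (with the default feature_scale=2, every channel count doubles)
        conv1 = self.conv1(inputs)           # 16*512*512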
        maxpool1 = self.maxpool(conv1)       # 16*256*256
        conv2 = self.conv2(maxpool1)         # 32*256*256
        maxpool2 = self.maxpool(conv2)       # 32*128*128

        conv3 = self.conv3(maxpool2)         # 64*128*128
        maxpool3 = self.maxpool(conv3)       # 64*64*64

        conv4 = self.conv4(maxpool3)         # 128*64*64
        maxpool4 = self.maxpool(conv4)       # 128*32*32

        center = self.center(maxpool4)       # 256*32*32
        up4 = self.up_concat4(center,conv4)  # 128*64*64
        up3 = self.up_concat3(up4,conv3)     # 64*128*128
        up2 = self.up_concat2(up3,conv2)     # 32*256*256
        up1 = self.up_concat1(up2,conv1)     # 16*512*512

        final = self.final(up1)

        return final
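  • the forward pass downsamples first (the encoder path: conv1 to conv4, then center) and then upsamples (the decoder path: up_concat4 to up_concat1)
  • at each downsampling step, MaxPool2d(kernel_size=2) halves the height and width, while the following convolution block increases the number of channels
  • at each upsampling step, the feature map is brought back to twice its height and width and concatenated with the encoder feature map of the same size (e.g. up_concat4(center, conv4)); these skip connections pass fine spatial detail from the encoder straight to the decoder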
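To see where the shape comments in the forward pass come from, here is a small pure-Python sketch that traces (channels, height, width) through the encoder and decoder. It is a hypothetical helper, not part of ShawnBIT's code, and it assumes padded 3x3 convolutions that keep H and W unchanged (as the shape comments imply) and an input size divisible by 16.

```python
def unet_shapes(in_hw=512, feature_scale=2, depth=4):
    """Trace (name, channels, height, width) through the U-Net (hypothetical helper)."""
    filters = [int(f / feature_scale) for f in (64, 128, 256, 512, 1024)]
    shapes, hw = [], in_hw
    # encoder: each conv block sets the channel count, then MaxPool2d(2) halves H and W
    for level in range(depth):
        shapes.append((f"conv{level + 1}", filters[level], hw, hw))
        hw //= 2
    shapes.append(("center", filters[depth], hw, hw))
    # decoder: each up step doubles H and W and returns to the matching encoder's channel count
    for level in reversed(range(depth)):
        hw *= 2
        shapes.append((f"up{level + 1}", filters[level], hw, hw))
    return shapes

for name, c, h, w in unet_shapes(feature_scale=4):
    print(f"{name:>6}: {c}*{h}*{w}")
```

With feature_scale=4 this reproduces the conv, center and up comments in the forward pass (16*512*512 down to 256*32*32 and back); with the default feature_scale=2 every channel count doubles.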