These are supplementary notes for the HuggingFace diffusion models class unit 1. The video below is a walkthrough of unit 1 plus these notes.
What does a diffusion model do?
Step 2: Preparing the training dataset
Why do we need these data augmentations, and what do they do?
```python
from torchvision import transforms

# Define data augmentations
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)
```
- images are usually stored as 8-bit integers between 0 and 255 for each of the RGB channels, but neural networks typically prefer floating point values in a smaller range
- `Resize`: resizes the images so that we can train our neural net faster on a smaller image size (less data to work with)
- `RandomHorizontalFlip`: randomly flips images so that we effectively have more data to work with
- `ToTensor`: converts the input image to a `torch.FloatTensor` of shape (C x H x W) scaled to [0.0, 1.0]
  - we need to convert it to a tensor so that it's in a data format PyTorch can use
  - C is channels, H is height, W is width; we have 3 channels, one for each of R, G, and B
- `Normalize`: shifts and scales the tensor values using a mean of `0.5` and a standard deviation of `0.5`, mapping [0, 1] to [-1, 1], since neural networks work better with a small, centered range of floating point values (see the sketch after this list)
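As a quick sanity check, here is a minimal sketch (assuming torchvision is installed and that `image_size = 32`, a value chosen just for this example) that runs the pipeline on a dummy PIL image and prints the resulting shape and value range:

```python
import torch
from PIL import Image
from torchvision import transforms

image_size = 32  # assumed value for this sketch

preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# A dummy RGB image standing in for one training example
img = Image.new("RGB", (64, 48), color=(255, 128, 0))

x = preprocess(img)
print(x.shape)           # torch.Size([3, 32, 32]) -> (C, H, W)
print(x.min(), x.max())  # values now lie in [-1, 1] instead of 0..255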
Step 3: Define the scheduler
What does this equation mean?
\(q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})\)
- remember that the input to our model is a 'noisy' image, so we have to add noise to our input images at each timestep
- the amount of noise added at each step is drawn from a normal distribution ( \(\mathcal{N}\) ) with mean \(\sqrt{1 - \beta_t}\,\mathbf{x}_{t-1}\) and variance \(\beta_t\mathbf{I}\) at timestep \(t\)
Using the reparameterization trick, this means that at each timestep we scale the previous image and add Gaussian noise:
\(\mathbf{x}_t = \sqrt{1-\beta_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\)
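Before looking at a from-scratch implementation, note that in practice a noise scheduler handles this for us. Here is a minimal sketch using the diffusers library's `DDPMScheduler`; the batch size, image size, and beta values below are illustrative choices, not necessarily the notebook's settings:

```python
import torch
from diffusers import DDPMScheduler

# Illustrative settings; the exact values used in training may differ
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)

clean_images = torch.randn(4, 3, 32, 32)         # a fake batch of images in roughly [-1, 1]
noise = torch.randn_like(clean_images)           # eps ~ N(0, I)
timesteps = torch.randint(0, 1000, (4,)).long()  # a random timestep per image

# x_t computed from x_0, following the forward-diffusion equations above
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
```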
Here's what the forward diffusion pass looks like implemented from scratch (thanks J.M. Tomczak - code):
```python
import torch
import torch.nn as nn

class DDGM(nn.Module):
    def __init__(self, p_dnns, decoder_net, beta, T, D):
        super(DDGM, self).__init__()

        self.p_dnns = p_dnns  # a list of sequentials
        self.decoder_net = decoder_net

        # other params
        self.D = D
        self.T = T
        self.beta = torch.FloatTensor([beta])

    @staticmethod
    def reparameterization(mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def reparameterization_gaussian_diffusion(self, x, i):
        # x_t = sqrt(1 - beta) * x_{t-1} + sqrt(beta) * eps, with eps ~ N(0, I)
        return torch.sqrt(1. - self.beta) * x + torch.sqrt(self.beta) * torch.randn_like(x)

    def forward(self, x, reduction='avg'):
        # =====
        # forward diffusion
        zs = [self.reparameterization_gaussian_diffusion(x, 0)]

        for i in range(1, self.T):
            zs.append(self.reparameterization_gaussian_diffusion(zs[-1], i))
```
- take a look specifically at `reparameterization_gaussian_diffusion`, which is the code implementation of the reparameterized equation above (see the sketch below)
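To build intuition, here is a minimal standalone sketch of that same update (the values `beta = 0.05` and `T = 50` are arbitrary choices for illustration, not the course's schedule), showing a fake image turning into roughly standard-normal noise after enough steps:

```python
import torch

beta = 0.05  # assumed constant noise level, for illustration only
T = 50       # assumed number of diffusion steps

def diffusion_step(x, beta):
    # x_t = sqrt(1 - beta) * x_{t-1} + sqrt(beta) * eps, with eps ~ N(0, I)
    eps = torch.randn_like(x)
    return (1.0 - beta) ** 0.5 * x + beta ** 0.5 * eps

x = torch.rand(3, 32, 32) * 2 - 1  # a fake "image" with values in [-1, 1]
for t in range(T):
    x = diffusion_step(x, beta)

# After enough steps the original signal is mostly gone: mean ~ 0, std ~ 1
print(x.mean().item(), x.std().item())
```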
Step 4: Define the Model
What is the U-net model?
- originally used for biomedical image segmentation
- image segmentation is a task where you try to assign each pixel of an image to a certain class
- the images below show the respective input image and segmented output
Here is what the code for the U-Net model looks like (thanks ShawnBIT):
```python
import torch.nn as nn

# unetConv2, unetUp, and init_weights are helpers defined in the original repository
class UNet(nn.Module):
    def __init__(self, in_channels=1, n_classes=2, feature_scale=2, is_deconv=True, is_batchnorm=True):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        conv1 = self.conv1(inputs)      # 16*512*512
        maxpool1 = self.maxpool(conv1)  # 16*256*256

        conv2 = self.conv2(maxpool1)    # 32*256*256
        maxpool2 = self.maxpool(conv2)  # 32*128*128

        conv3 = self.conv3(maxpool2)    # 64*128*128
        maxpool3 = self.maxpool(conv3)  # 64*64*64

        conv4 = self.conv4(maxpool3)    # 128*64*64
        maxpool4 = self.maxpool(conv4)  # 128*32*32

        center = self.center(maxpool4)  # 256*32*32
        up4 = self.up_concat4(center, conv4)  # 128*64*64
        up3 = self.up_concat3(up4, conv3)     # 64*128*128
        up2 = self.up_concat2(up3, conv2)     # 32*256*256
        up1 = self.up_concat1(up2, conv1)     # 16*512*512

        final = self.final(up1)

        return final
```
- the forward pass does downsampling first and then upsampling
- at each downsampling step the spatial dimensions are halved using `MaxPool2d`, as the snippet below shows
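As a quick illustration of that halving (a minimal sketch, not taken from the course code; the tensor shape matches the comments in the forward pass above):

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size

x = torch.randn(1, 16, 512, 512)  # (batch, channels, H, W)
print(pool(x).shape)              # torch.Size([1, 16, 256, 256]) -> H and W halved
```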