一、前言
在上一篇文章〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理中,我完成了對圖片加入雜訊的部分,因此接下來就輪到模型的拆解。
二、U-Net 模型簡介
圖片來源:【Deep Learning for Image Segmentation: U-Net Architecture】
在DDPM論文中,作者使用了U-Net這種模型架構來進行訓練。U-Net是Auto-encoder的變種,可以看到下方一樣有一個bottleneck的部分,且輸入和輸出圖片的大小相同。U-Net在image segmantation的領域有著重大貢獻,與傳統的Auto-encoder不同的是,U-Net在encoder和decoder之間有使用residual connection,以更好的保留原始圖片的特徵。
三、U-Net 架構實作
若要進行U-Net的實作,可以拆解成下方幾個的零件實作。
- 兩層CNN的Block
- time embedding
- Down(左半邊的Encoder,兩層CNN加上Maxpooling)
- Up(右半邊的Decoder,兩層CNN加上Upsample)
- self attention
- residual connection
1. 雙層CNN
先從最常用到的著手,先設計一個有兩層CNN的Block,在之後的地方都會用到
class DoubleConv(nn.Module):
def __init__(self):
pass
def forward(self):
pass
填入模型
class DoubleConv(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.GroupNorm(1, out_c), #equivalent with LayerNorm
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
nn.GroupNorm(1, out_c), #equivalent with LayerNorm
nn.ReLU()
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
2. Time Embedding
在訓練U-Net的時候,我一開始以為輸入是一張圖片,輸出只要給出被修復過的圖片就好。但其實這樣有一個問題,就是模型不知道不同timestep的圖片之間的差別,導致模型需要直接面對不同雜訊強度的圖片並進行修復。
embedding的概念簡單來說就是把一個單獨的值加工成一個tensor。比如我們對模型輸入圖片和一個整數(timestep),我們能透過embedding將那個整數換成一個tensor,變成讓模型更容易學習的形式。而DDPM的作者選擇使用Sinusoidal Position Embedding來為單獨timestep做embedding。
看起來很厲害的Sinusoidal Position Embedding
(圖源:A Gentle Introduction to Positional Encoding in Transformer Models, Part 1)
這個問題有點像Transformer在訓練的時候用attention訓練時,需要將文字再加上一個positional embedding的概念相同,我們也需要為不同雜訊強度的圖片加上一個time embedding來告訴模型這是甚麼強度的圖片。
def pos_encoding(t, channels):
t = torch.tensor([t])
inv_freq = 1.0 / (
10000
** (torch.arange(0, channels, 2).float() / channels)
)
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
return pos_enc
接收兩個整數後回傳一個embedded好的Tensor,範例如下
pos_encoding(10, 16) #timestep = 10
tensor([[-0.5440, -0.0207, 0.8415, 0.3110, 0.0998, 0.0316, 0.0100, 0.0032,
-0.8391, -0.9998, 0.5403, 0.9504, 0.9950, 0.9995, 0.9999, 1.0000]])
當然這樣一個tensor肯定不能直接與圖片的tensor相加,在size上還需要調整,這個在後面會有提到。
3. Down & Up
接下來是Down和Up,簡單概念就是進行Maxpooling或Upsample後再加個DoubleConv
首先是Down的部分
class Down(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.down = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_c,out_c,first_residual=True),
)
def forward(self, x):
x = self.down(x)
return x
基本架構差不多是這樣,但是不要忘了我們還要為圖片加上time embedding
class Down(nn.Module):
def __init__(self, in_c, out_c, emb_dim=128):
super().__init__()
self.down = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_c,out_c),
)
self.emb_layer = nn.Sequential(
nn.ReLU(),
nn.Linear(emb_dim, out_c),
)
def forward(self, x, t):
x = self.down(x)
#擴充兩個dimension,然後使用repeat填滿成和圖片相同(如同numpy.tile)
t_emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + t_emb
Up的架構基本相同,但是如果看上面的圖,可以看到Up還需要接收一個類似residual connection的輸入,所以在forward()裡面會多一個skip_x
與x
接起來。
class Up(nn.Module):
def __init__(self, in_c, out_c, emb_dim=128):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConv(in_c,out_c)
self.emb_layer = nn.Sequential(
nn.SiLU(),
nn.Linear(emb_dim, out_c),
)
def forward(self, x, skip_x, t):
x = self.up(x)
x = torch.cat([skip_x, x], dim=1)
x = self.conv(x)
emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
return x + emb
4. Self Attention Block
這個部分沒打算細講(因為我也沒完全懂),之後可能會再寫一篇Attention is all you need的研究心得之類的。簡單來說Self Attention可以想成輸入一個向量,結果再輸出一個向量的黑盒子。(這邊直接照抄Outlier的程式碼)
class SelfAttention(nn.Module):
def __init__(self, channels, size):
super(SelfAttention, self).__init__()
self.channels = channels
self.size = size
self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
self.ln = nn.LayerNorm([channels])
self.ff_self = nn.Sequential(
nn.LayerNorm([channels]),
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels),
)
def forward(self, x):
x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
x_ln = self.ln(x)
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
attention_value = attention_value + x
attention_value = self.ff_self(attention_value) + attention_value
return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
5. 組裝U-Net
最後我們把來把上面寫的東西組裝起來
class UNet(nn.Module):
def __init__(self, c_in=3, c_out=3, time_dim=128, device="cuda"):
super().__init__()
self.device = device
self.time_dim = time_dim
self.inc = DoubleConv(c_in, 64) #(b,3,64,64) -> (b,64,64,64)
self.down1 = Down(64, 128) #(b,64,64,64) -> (b,128,32,32)
self.sa1 = SelfAttention(128, 32) #(b,128,32,32) -> (b,128,32,32)
self.down2 = Down(128, 256) #(b,128,32,32) -> (b,256,16,16)
self.sa2 = SelfAttention(256, 16) #(b,256,16,16) -> (b,256,16,16)
self.down3 = Down(256, 256) #(b,256,16,16) -> (b,256,8,8)
self.sa3 = SelfAttention(256, 8) #(b,256,8,8) -> (b,256,8,8)
self.bot1 = DoubleConv(256, 512) #(b,256,8,8) -> (b,512,8,8)
self.bot2 = DoubleConv(512, 512) #(b,512,8,8) -> (b,512,8,8)
self.bot3 = DoubleConv(512, 256) #(b,512,8,8) -> (b,256,8,8)
self.up1 = Up(512, 128) #(b,512,8,8) -> (b,128,16,16) because the skip_x
self.sa4 = SelfAttention(128, 16) #(b,128,16,16) -> (b,128,16,16)
self.up2 = Up(256, 64) #(b,256,16,16) -> (b,64,32,32)
self.sa5 = SelfAttention(64, 32) #(b,64,32,32) -> (b,64,32,32)
self.up3 = Up(128, 64) #(b,128,32,32) -> (b,64,64,64)
self.sa6 = SelfAttention(64, 64) #(b,64,64,64) -> (b,64,64,64)
self.outc = nn.Conv2d(64, c_out, kernel_size=1) #(b,64,64,64) -> (b,3,64,64)
def pos_encoding(self, t, channels):
t = torch.tensor([t])
inv_freq = 1.0 / (
10000
** (torch.arange(0, channels, 2).float() / channels)
)
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
return pos_enc
def forward(self, x, t):
t = t.unsqueeze(-1).type(torch.float)
t = self.pos_encoding(t, self.time_dim)
#initial conv
x1 = self.inc(x)
#Down
x2 = self.down1(x1, t)
x2 = self.sa1(x2)
x3 = self.down2(x2, t)
x3 = self.sa2(x3)
x4 = self.down3(x3, t)
x4 = self.sa3(x4)
#Bottle neck
x4 = self.bot1(x4)
x4 = self.bot2(x4)
x4 = self.bot3(x4)
#Up
x = self.up1(x4, x3, t)
x = self.sa4(x)
x = self.up2(x, x2, t)
x = self.sa5(x)
x = self.up3(x, x1, t)
x = self.sa6(x)
#Output
output = self.outc(x)
return output
確認一下是否能正常運作以及輸出是否正確
sample = torch.randn((1, 3, 64, 64))
t = torch.tensor([10])
model = UNet()
model(sample, t).shape
Output:
torch.Size([1, 3, 64, 64])
水喔,U-Net 模型的部分搞定了
四、結語
本來想多講一點的(圖片修復的部分)但寫到這裡已經快8000字了,下個部分沒意外應該就是完結了,看能不能寫完圖片修復和模型訓練。可能會再額外寫一篇Extra講如何改進什麼的,都是後話了。
相關資料
https://www.youtube.com/watch?v=a4Yfz2FxXiY
https://www.youtube.com/watch?v=HoKDTa5jHvg&t=1338s
https://huggingface.co/blog/annotated-diffusion
https://arxiv.org/pdf/2102.09672.pdf
https://arxiv.org/pdf/1503.03585.pdf
https://arxiv.org/pdf/2006.11239.pdf
https://theaisummer.com/latent-variable-models/#reparameterization-trick
https://theaisummer.com/diffusion-models/
https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/