# 〈 Diffusion Model 論文研究與實作心得 Part.2 〉 U-Net 模型架構介紹與實作


一、前言

在上一篇文章〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理中,我完成了對圖片加入雜訊的部分,因此接下來就輪到模型的拆解。

二、U-Net 模型簡介

圖片來源:【Deep Learning for Image Segmentation: U-Net Architecture】

在DDPM論文中,作者使用了U-Net這種模型架構來進行訓練。U-Net是Auto-encoder的變種,可以看到下方一樣有一個bottleneck的部分,且輸入和輸出圖片的大小相同。U-Net在image segmantation的領域有著重大貢獻,與傳統的Auto-encoder不同的是,U-Net在encoder和decoder之間有使用residual connection,以更好的保留原始圖片的特徵。

三、U-Net 架構實作

若要進行U-Net的實作,可以拆解成下方幾個的零件實作。

  • 兩層CNN的Block
  • time embedding
  • Down(左半邊的Encoder,兩層CNN加上Maxpooling)
  • Up(右半邊的Decoder,兩層CNN加上Upsample)
  • self attention
  • residual connection

1. 雙層CNN

先從最常用到的著手,先設計一個有兩層CNN的Block,在之後的地方都會用到

class DoubleConv(nn.Module):
  def __init__(self):
    pass

  def forward(self):
    pass

填入模型

class DoubleConv(nn.Module):
  def __init__(self, in_c, out_c):
    super().__init__()
    self.conv1 = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.GroupNorm(1, out_c), #equivalent with LayerNorm
        nn.ReLU()
    )
    self.conv2 = nn.Sequential(
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
        nn.GroupNorm(1, out_c), #equivalent with LayerNorm
        nn.ReLU()
    )

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    return x

2. Time Embedding

在訓練U-Net的時候,我一開始以為輸入是一張圖片,輸出只要給出被修復過的圖片就好。但其實這樣有一個問題,就是模型不知道不同timestep的圖片之間的差別,導致模型需要直接面對不同雜訊強度的圖片並進行修復。

embedding的概念簡單來說就是把一個單獨的值加工成一個tensor。比如我們對模型輸入圖片和一個整數(timestep),我們能透過embedding將那個整數換成一個tensor,變成讓模型更容易學習的形式。而DDPM的作者選擇使用Sinusoidal Position Embedding來為單獨timestep做embedding。

看起來很厲害的Sinusoidal Position Embedding
(圖源:A Gentle Introduction to Positional Encoding in Transformer Models, Part 1)

這個問題有點像Transformer在訓練的時候用attention訓練時,需要將文字再加上一個positional embedding的概念相同,我們也需要為不同雜訊強度的圖片加上一個time embedding來告訴模型這是甚麼強度的圖片。

def pos_encoding(t, channels):
  t = torch.tensor([t])
  inv_freq = 1.0 / (
    10000
    ** (torch.arange(0, channels, 2).float() / channels)
  )
  pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
  pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
  pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
  return pos_enc

接收兩個整數後回傳一個embedded好的Tensor,範例如下

pos_encoding(10, 16) #timestep = 10
tensor([[-0.5440, -0.0207,  0.8415,  0.3110,  0.0998,  0.0316,  0.0100,  0.0032,
         -0.8391, -0.9998,  0.5403,  0.9504,  0.9950,  0.9995,  0.9999,  1.0000]])

當然這樣一個tensor肯定不能直接與圖片的tensor相加,在size上還需要調整,這個在後面會有提到。

3. Down & Up

接下來是Down和Up,簡單概念就是進行Maxpooling或Upsample後再加個DoubleConv

首先是Down的部分

class Down(nn.Module):
  def __init__(self, in_c, out_c):
    super().__init__()
    self.down = nn.Sequential(
        nn.MaxPool2d(2),
        DoubleConv(in_c,out_c,first_residual=True),
    )

  def forward(self, x):
    x = self.down(x)
    return x

基本架構差不多是這樣,但是不要忘了我們還要為圖片加上time embedding

class Down(nn.Module):
  def __init__(self, in_c, out_c, emb_dim=128):
    super().__init__()
    self.down = nn.Sequential(
        nn.MaxPool2d(2),
        DoubleConv(in_c,out_c),
    )

    self.emb_layer = nn.Sequential(
        nn.ReLU(),
        nn.Linear(emb_dim, out_c),
    )

  def forward(self, x, t):
    x = self.down(x)
    #擴充兩個dimension,然後使用repeat填滿成和圖片相同(如同numpy.tile)
    t_emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) 
    return x + t_emb

Up的架構基本相同,但是如果看上面的圖,可以看到Up還需要接收一個類似residual connection的輸入,所以在forward()裡面會多一個skip_xx接起來。

class Up(nn.Module):
    def __init__(self, in_c, out_c, emb_dim=128):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = DoubleConv(in_c,out_c)
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_c),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

4. Self Attention Block

這個部分沒打算細講(因為我也沒完全懂),之後可能會再寫一篇Attention is all you need的研究心得之類的。簡單來說Self Attention可以想成輸入一個向量,結果再輸出一個向量的黑盒子。(這邊直接照抄Outlier的程式碼)

class SelfAttention(nn.Module):
    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)

5. 組裝U-Net

最後我們把來把上面寫的東西組裝起來

class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=128, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        self.inc = DoubleConv(c_in, 64) #(b,3,64,64) -> (b,64,64,64)

        self.down1 = Down(64, 128) #(b,64,64,64) -> (b,128,32,32)
        self.sa1 = SelfAttention(128, 32) #(b,128,32,32) -> (b,128,32,32)
        self.down2 = Down(128, 256) #(b,128,32,32) -> (b,256,16,16)
        self.sa2 = SelfAttention(256, 16) #(b,256,16,16) -> (b,256,16,16)
        self.down3 = Down(256, 256) #(b,256,16,16) -> (b,256,8,8)
        self.sa3 = SelfAttention(256, 8) #(b,256,8,8) -> (b,256,8,8)

        self.bot1 = DoubleConv(256, 512) #(b,256,8,8) -> (b,512,8,8)
        self.bot2 = DoubleConv(512, 512) #(b,512,8,8) -> (b,512,8,8)
        self.bot3 = DoubleConv(512, 256) #(b,512,8,8) -> (b,256,8,8)

        self.up1 = Up(512, 128) #(b,512,8,8) -> (b,128,16,16) because the skip_x
        self.sa4 = SelfAttention(128, 16) #(b,128,16,16) -> (b,128,16,16)
        self.up2 = Up(256, 64) #(b,256,16,16) -> (b,64,32,32)
        self.sa5 = SelfAttention(64, 32) #(b,64,32,32) -> (b,64,32,32)
        self.up3 = Up(128, 64) #(b,128,32,32) -> (b,64,64,64)
        self.sa6 = SelfAttention(64, 64) #(b,64,64,64) -> (b,64,64,64)

        self.outc = nn.Conv2d(64, c_out, kernel_size=1) #(b,64,64,64) -> (b,3,64,64)

    def pos_encoding(self, t, channels):
        t = torch.tensor([t])
        inv_freq = 1.0 / (
         10000
         ** (torch.arange(0, channels, 2).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        #initial conv
        x1 = self.inc(x)

        #Down
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        #Bottle neck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        #Up
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)

        #Output
        output = self.outc(x)
        return output

確認一下是否能正常運作以及輸出是否正確

sample = torch.randn((1, 3, 64, 64))
t = torch.tensor([10])

model = UNet()
model(sample, t).shape

Output:

torch.Size([1, 3, 64, 64])

水喔,U-Net 模型的部分搞定了

四、結語

本來想多講一點的(圖片修復的部分)但寫到這裡已經快8000字了,下個部分沒意外應該就是完結了,看能不能寫完圖片修復和模型訓練。可能會再額外寫一篇Extra講如何改進什麼的,都是後話了。

相關資料

https://www.youtube.com/watch?v=a4Yfz2FxXiY
https://www.youtube.com/watch?v=HoKDTa5jHvg&t=1338s
https://huggingface.co/blog/annotated-diffusion
https://arxiv.org/pdf/2102.09672.pdf
https://arxiv.org/pdf/1503.03585.pdf
https://arxiv.org/pdf/2006.11239.pdf
https://theaisummer.com/latent-variable-models/#reparameterization-trick
https://theaisummer.com/diffusion-models/
https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/

#AI #Deep Learning #Diffusion Model







你可能感興趣的文章

安裝 Go 環境

安裝 Go 環境

[重新理解 C++]  TMP(1): compiling time recursion

[重新理解 C++] TMP(1): compiling time recursion

Npm 套件收納箱

Npm 套件收納箱






留言討論