一、前言
在前兩篇文章〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理 和 〈 Diffusion Model 論文研究與實作心得 Part.2 〉 U-Net 模型架構介紹與實作中,我完成了資料前處理與模型的搭建。因此在Part.3(最終篇)就要來進行模型的訓練和結果呈現。
二、模型訓練
我們可以參考一下ddpm作者的sudo code,這樣對實作的步驟有很大的幫助。
我們的模型輸出是預測圖片的雜訊(對,不是修復後的圖),拿去和加在上面的雜訊進行比較。所以get_loss
函數應該有三個參數,X_0,timestep和model。
def get_loss(x_0, t, model):
pass
比較所需要的有三個東西
- 某個timestep的x
- 實際加上的雜訊
- 模型預測的雜訊
1.和2.可以用〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理 裡定義的
def forward_diffuse_process(x_0, t):
'''
回傳第t個timestep的圖片和加上的雜訊
'''
noise = torch.randn_like(x_0) #回傳與X_0相同size的noise tensor,也就是reparameterization的epsilon
sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t]
sqrt_oneminus_alphas_cumprod_t = sqrt_oneminus_alphas_cumprod[t]
#element-wise的運算
return sqrt_alphas_cumprod_t*x_0 + sqrt_oneminus_alphas_cumprod_t*noise, noise
這個函數會回傳前兩點需要的東西。
def get_loss(x_0, t, model):
x_t, noise = forward_diffuse_process(x_0, t)
而3. 則需要使用我們上次架構的U-net模型
def get_loss(x_0, t, model):
x_t, noise = forward_diffuse_process(x_0, t)
noise_prediction = model(x_t, t)
最後對noise和noise_prediction進行比較就能得到Loss了,這邊選用L2 Loss
def get_loss(x_0, t, model):
x_t, noise = forward_diffuse_process(x_0, t)
noise_prediction = model(x_t, t)
return F.l2_loss(noise, noise_prediction)
如此一來就能開始進行訓練了!
optimizer選用Adam,epochs先選用20 (colab的資源讓我一次只敢做這麼多QQ)
from torch.optim import Adam
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 20 # Try more!
for epoch in range(epochs):
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
loss = get_loss(model, batch[0], t)
loss.backward()
optimizer.step()
if epoch % 5 == 0 and step == 0:
print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
訓練output:
Epoch 0 | step 000 Loss: 0.8118380904197693
Epoch 5 | step 000 Loss: 0.2767971158027649
Epoch 10 | step 000 Loss: 0.29156017303466797
Epoch 15 | step 000 Loss: 0.24683958292007446
Epoch 20 | step 000 Loss: 0.22735241055488586
這個專案的心臟,Diffusion Model正式訓練完成 (感動
三、圖片修復與成果呈現
說到底,我們模型輸出的終究只是對雜訊的預測,因此還需要一點點的數學才能將這個雜訊預測用於修復原圖。
還記得第一篇提到的q(Xt|Xt-1)嗎?那是用於破壞照片的forward process,而現在的backward process(修復圖片)ddpm的論文作者使用p(Xt-1|Xt)代表。
這部分牽涉到很複雜的數學(我也不太懂),所以我就放一部份的筆記和完整數學算式的連結
總之經過一點魔法我們能透過最底下框起來的式子計算出前一步timestep的圖。
論文中作者好像"憑經驗"省略了一堆數學還得到更好的結果,所以實作的部分就依照上面的sudo code就行了。
此外,這邊會用到第一篇定義的變數,我放在下面方便理解。
# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)
# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
#新定義的
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
@torch.no_grad() #記得寫這行,在sample的時候才不會逆向傳遞梯度
def sample_timestep(x, t):
"""
給一個被破壞的圖片x和timestep,回傳修復後的圖片
"""
#這邊基本都是按照sudo code的算式
betas_t = get_index_from_list(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
# Call model (current image - noise prediction)
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
if t == 0:
return model_mean
else:
noise = torch.randn_like(x)
return model_mean + torch.sqrt(posterior_variance_t) * noise
@torch.no_grad()
def sample_and_plot_image():
#首先,生成隨機雜訊
img_size = IMG_SIZE
img = torch.randn((1, 3, img_size, img_size), device=device)
#這部分是用plt來呈現成果
plt.figure(figsize=(15,15))
plt.axis('off')
num_images = 10
stepsize = int(T/num_images)
#從第T個timestep修復到第0個
for i in range(0,T)[::-1]:
t = torch.full((1,), i, device=device, dtype=torch.long)
img = sample_timestep(img, t)
if i % stepsize == 0:
plt.subplot(1, num_images, int(i/stepsize+1))
show_tensor_image(img.detach().cpu()) #第一篇的函式
plt.show()
來看看epoch=80時候的成果:
呃...雖然有點抽象,但多少能看出類似臉、眼睛、頭髮的色塊,如果將epochs調高一點應該能得到更好的成果。
四、結語(系列總結)
第一次寫這種系列文,從資料前處裡到訓練模型,雖然省略了很多細節,很多地方可能做得不夠好,但我對自己踏出的第一步感到挺滿意的。
我之後可能會再寫一篇外傳 (?,講講怎麼改造這個模型,讓他能產出更高畫質的圖片或變成prompt-to-image模型,又或者我搞了一張顯卡把epochs跑完再來看看成果之類的。一樣,都是後話了。
相關資料
https://www.youtube.com/watch?v=a4Yfz2FxXiY
https://www.youtube.com/watch?v=HoKDTa5jHvg&t=1338s
https://huggingface.co/blog/annotated-diffusion
https://arxiv.org/pdf/2102.09672.pdf
https://arxiv.org/pdf/1503.03585.pdf
https://arxiv.org/pdf/2006.11239.pdf
https://theaisummer.com/latent-variable-models/#reparameterization-trick
https://theaisummer.com/diffusion-models/
https://brohrer.mcknote.com/zh-
https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#nice