PyTorch 踩坑紀錄:one of the variables needed for gradient computation has been modified by an inplace operation


TL;DR

在 PyTorch 裏,除了官方文件寫明的 in place operations,另外還有+=*= 都是 in-place operations。有可能導致 back propagation 出錯。


關於 in place

最近在寫 PyTorch 的時候踩到一個坑如下,貌似是與 in-place operation 有關的錯誤。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [24, 1]], which is output 0 of ViewBackward, is at version 14; expected version 13 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

就我的印象所及,in-place operation 只有以前看到的 in-place tanh()scatter_add_() 這兩個東西。但試著把他們排除掉後,錯誤還是持續發生。

with torch.autograd.set_detect_anomaly(True): 設定起來後,多顯示出來的出錯的部分卻指向一個再簡單不過的全連接層的 forward lgt = self.fc(cov_t),把此行做各種排除還是沒有用。

好在在 PyTorch 論壇裡面找到 Alban Desmaison 大神的回答。原來 +=*= 都是 in-place 運算。

而我原本的 code 裡面有一行是

x1 += k

改成

x1 = x1 + k

就再也沒出現上面的錯誤了。


這樣想一想,PyTorch 確實寫得很 Pythonic。根據 Fluent Python 裡面說的,+= 會先去找有沒有 __iadd__()(in place addition) 才會再去找 __add__()+= 的語意是 in place addition 沒有錯,那自然在 PyTorch 裡面也是用這樣的做法了。


之前都只有想到 function 會有 in-place 的問題,沒有注意到運算子也會有。謹此紀錄。

#PyTorch #torch #Machine-Learning






你可能感興趣的文章

跟爺爺奶奶們度過開心快樂的夏令營

跟爺爺奶奶們度過開心快樂的夏令營

Fetch 與 Promise (五):async 與 await

Fetch 與 Promise (五):async 與 await

Client Side Rendering(CSR) 的缺點與優化

Client Side Rendering(CSR) 的缺點與優化






留言討論