TL;DR
在 PyTorch 裏,除了官方文件寫明的 in place operations,另外還有+=
或 *=
都是 in-place operations。有可能導致 back propagation 出錯。
關於 in place
最近在寫 PyTorch 的時候踩到一個坑如下,貌似是與 in-place operation 有關的錯誤。
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [24, 1]], which is output 0 of ViewBackward, is at version 14; expected version 13 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
就我的印象所及,in-place operation 只有以前看到的 in-place tanh()
和 scatter_add_()
這兩個東西。但試著把他們排除掉後,錯誤還是持續發生。
把 with torch.autograd.set_detect_anomaly(True):
設定起來後,多顯示出來的出錯的部分卻指向一個再簡單不過的全連接層的 forward lgt = self.fc(cov_t)
,把此行做各種排除還是沒有用。
好在在 PyTorch 論壇裡面找到 Alban Desmaison 大神的回答。原來 +=
和 *=
都是 in-place 運算。
而我原本的 code 裡面有一行是
x1 += k
改成
x1 = x1 + k
就再也沒出現上面的錯誤了。
這樣想一想,PyTorch 確實寫得很 Pythonic。根據 Fluent Python 裡面說的,+=
會先去找有沒有 __iadd__()
(in place addition) 才會再去找 __add__()
。+=
的語意是 in place addition 沒有錯,那自然在 PyTorch 裡面也是用這樣的做法了。
之前都只有想到 function 會有 in-place 的問題,沒有注意到運算子也會有。謹此紀錄。