写点什么

PyTorch:常见错误 inplace operation

作者:强劲九
  • 2022 年 1 月 30 日
  • 本文字数:1265 字

    阅读完需:约 4 分钟

inplace 操作是 PyTorch 里面一个比较常见的错误,有的时候会比较好发现,例如下面的代码:


 import torch w = torch.rand(4, requires_grad=True) w += 1 loss = w.sum() loss.backward()
复制代码


执行 loss 对参数 w 进行求导,会出现报错:RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.


导致这个报错的主要是第 3 行代码 w += 1,如果把这句改成 w = w + 1,再执行就不会报错了。这种写法导致的 inplace operation 是比较好发现的,但是有的时候同样类似的报错,会比较不好发现。例如下面的代码:


 import torch x = torch.zeros(4) w = torch.rand(4, requires_grad=True) x[0] = torch.rand(1) * w[0] for i in range(3):     x[i+1] = torch.sin(x[i]) * w[i] loss = x.sum() loss.backward()
复制代码


执行之后会出现报错:


 >>> RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor []], which is output 0 of SelectBackward, is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
复制代码


根据提示我们可以使用 with torch.autograd.set_detect_anomaly(True) 来帮助我们定位具体的出错位置(这个方法会花费比较长的时间)。


 with torch.autograd.set_detect_anomaly(True):     x = torch.zeros(4)     w = torch.rand(4, requires_grad=True)     x[0] = torch.rand(1) * w[0]     for i in range(3):         x[i+1] = torch.sin(x[i]) * w[i]     loss = x.sum()     loss.backward()
复制代码


运行会增加这些报错:


 >>> /Users/strongnine/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:130: UserWarning: Error detected in SinBackward. Traceback of forward call that caused the error:
复制代码


可以看到出现了 Error detected in SinBackward.,这句描述,我们可以猜测大概是 torch.sin() 这个函数出现了问题。实际上,这个报错的解决办法,就是将第 6 行代码 x[i+1] = torch.sin(x[i]) * w[i] 改成 x[i+1] = torch.sin(x[i].clone()) * w[i],就行了。


 import torch x = torch.zeros(4) w = torch.rand(4, requires_grad=True) x[0] = torch.rand(1) * w[0] for i in range(3):     x[i+1] = torch.sin(x[i].clone()) * w[i] loss = x.sum() loss.backward()
复制代码


总结一下,遇到 inplace operation 的报错,一般可以通过:


  • x += 1 改成 x = x + 1

  • x[:, :, 0:3] = x[:, :, 0:3] + 1 改成 x[:, :, 0:3] = x[:, :, 0:3].clone() + 1

  • x[i+1] = torch.sin(x[i]) * w[i] 改成 x[i+1] = torch.sin(x[i].clone()) * w[i]


如果自己检查不出是哪里出现了问题,可以使用 with torch.autograd.set_detect_anomaly(True) 来帮助我们定位具体的出错位置,但是要注意的是这个方法一般会运行比较长的时间。


参考:

知乎:PyTorch debug 整理;

发布于: 2022 年 01 月 30 日阅读数: 19
用户头像

强劲九

关注

学习是变有趣的第一步。 2018.09.17 加入

公众号:strongnine

评论

发布
暂无评论
PyTorch:常见错误 inplace operation