tensor.detach()


x = torch.tensor(2.0)
x.requires_grad_(True)
y = 2 * x
z = 5 * x

w = y + z.detach()
w.backward()

print(x.grad)

=> 2

本来应该x的梯度为7,但是detach()那一路切段了梯度的传播,导致5没有向后传递

相关