PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例
来源
变量.grad_fn表明该变量是怎么来的,用于指导反向传播。例如loss = a+b,则loss.gard_fn为
程序示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
w1 = torch.tensor( 2.0 , requires_grad = True )
a = torch.tensor([[ 1. , 2. ], [ 3. , 4. ]], requires_grad = True )
tmp = a[ 0 , :]
tmp.retain_grad() # tmp是非叶子张量,需用.retain_grad()方法保留导数,否则导数将会在反向传播完成之后被释放掉
b = tmp.repeat([ 3 , 1 ])
b.retain_grad()
loss = (b * w1).mean()
loss.backward()
print (b.grad_fn) # 输出:
print (b.grad) # 输出: tensor([[0.3333, 0.3333],
# [0.3333, 0.3333],
# [0.3333, 0.3333]])
print (tmp.grad_fn) # 输出:
print (tmp.grad) # 输出:tensor([1., 1.])
print (a.grad) # 输出:tensor([[1., 1.],
# [0., 0.]])
|
手动推导:
手动推导的结果和程序的结果是一致的。