PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例


来源   

 

变量.grad_fn表明该变量是怎么来的,用于指导反向传播。例如loss = a+b,则loss.gard_fn为,表明loss是由相加得来的,这个grad_fn可指导怎么求a和b的导数。

程序示例:

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 import torch   w1 = torch.tensor(2.0, requires_grad=True) = torch.tensor([[1.2.], [3.4.]], requires_grad=True) tmp = a[0, :] tmp.retain_grad()   # tmp是非叶子张量,需用.retain_grad()方法保留导数,否则导数将会在反向传播完成之后被释放掉 = tmp.repeat([31]) b.retain_grad() loss = (b * w1).mean() loss.backward()   print(b.grad_fn)    # 输出: print(b.grad)       # 输出: tensor([[0.3333, 0.3333],                     #               [0.3333, 0.3333],                     #               [0.3333, 0.3333]])   print(tmp.grad_fn)    # 输出: print(tmp.grad)       # 输出:tensor([1., 1.])     print(a.grad)     # 输出:tensor([[1., 1.],                   #              [0., 0.]])

手动推导:

手动推导的结果和程序的结果是一致的。

相关