model.apply(fn)或net.apply(fn)
详情可参考:https://pytorch.org/docs/1.11/generated/torch.nn.Module.html?highlight=torch%20nn%20module%20apply#torch.nn.Module.apply
首先,我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module,也就是模块。
pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。
比如下面的网络例子中。net这个模块有两个子模块,分别为Linear(2,4)和Linear(4,8)。函数首先对Linear(2,4)和Linear(4,8)两个子模块调用init_weights函数,即print(m)打印Linear(2,4)和Linear(4,8)两个子模块。然后再对net模块进行同样的操作。如此完成递归地调用。从而完成model.apply(fn)或者net.apply(fn)。
个人水平有限,不足处望指正。
参考链接:https://blog.csdn.net/qq_37025073/article/details/106739513
@torch.no_grad() def init_weights(m): print(m) if type(m) == nn.Linear: m.weight.fill_(1.0) print(m.weight) net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) net.apply(init_weights)
#输出: Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[ 1., 1.], [ 1., 1.]]) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[ 1., 1.], [ 1., 1.]]) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )
https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=nn%20module%20apply#torch.nn.Module.apply
如果我们想对某些特定的子模块submodule
做一些针对性的处理,该怎么做呢。我们可以加入type(m) == nn.Linear:
这类判断语句,从而对子模块m进行处理。如下,读者可以细细体会一下。
import torch.nn as nn @torch.no_grad() def init_weights(m): print(m) if type(m) == nn.Linear: m.weight.fill_(1.0) print(m.weight) net = nn.Sequential(nn.Linear(2,4), nn.Linear(4, 8)) print(net) print('isinstance torch.nn.Module',isinstance(net,torch.nn.Module)) print(' ') net.apply(init_weights)
可以先打印网络整体看看。调用apply
函数后,先逐一打印子模块m,然后对子模块进行判断,打印Linear
这类子模块m
的权重。
#输出: Sequential( (0): Linear(in_features=2, out_features=4, bias=True) (1): Linear(in_features=4, out_features=8, bias=True) ) isinstance torch.nn.Module True Linear(in_features=2, out_features=4, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=4, out_features=8, bias=True) Parameter containing: tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=4, bias=True) (1): Linear(in_features=4, out_features=8, bias=True) )
网友:所以说apply函数是有顺序的,先在子模块上操作,最后在父模块上操作。