Pytorch的view方法


参考链接:

https://blog.csdn.net/scut_salmon/article/details/82391320

结论:Pytorch里的view方法用于改变数据维度,与numpy的reshape方法类似。

一.按照传入数字使数据维度进行转换

示例

a = torch.arange(1, 17)  # a's shape is (16,)
 
a.view(4, 4) # output below
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])
[torch.FloatTensor of size 4x4]
 
a.view(2, 2, 4) # output below
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],
 
        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]])
[torch.FloatTensor of size 2x2x4]

二.传入数字-1,自动对维度进行变换

在某一个维度,我们可以传入数字-1,自动对维度进行计算并变化:

假设我们有一个数据维度为【3,5,2】的tensor,我们想要将其转化为其中两个维度分别为【3,1】,【5,2】,而剩下的第三个维度自动进行计算,那么我们可以使用-1来代替【3,1,10】当中的10,以及用-1来代替转化后【5,2,3】维度当中的数字3.我们可以发现3110=352=523,因此变化后的维度乘积是相等的。

示例

import torch
a=torch.randn(3,5,2)
print(a)
print(a.view(3,1,-1).size())
print(a.view([3,1,-1]).size()) #不管加不加上列表符号,最后reshape的结果是一样的
print(a.view([5,2,-1]).size())

结果

tensor([[[ 1.6498, -0.4354],
         [-1.0042, -0.1582],
         [ 1.2794, -0.1203],
         [ 0.9198,  2.8475],
         [ 0.0065,  1.5481]],

        [[ 0.7220, -1.1230],
         [ 0.2665, -0.6645],
         [-0.6159, -0.3833],
         [-1.4767,  0.8378],
         [-0.3257,  0.2394]],

        [[ 0.3784,  0.4233],
         [-0.5807,  1.2695],
         [ 1.7632,  0.7828],
         [ 1.0076,  0.6205],
         [ 0.9948, -1.2256]]])
torch.Size([3, 1, 10])
torch.Size([3, 1, 10])
torch.Size([5, 2, 3])

三、使用view可以进行数据降维

只要元素数量能够对应,我们即可使用view来进行数据降维。

示例

import torch
a = torch.randn(3,5,2)
print(a.size())
print(a.view([3,-1]).size())

结果

torch.Size([3, 5, 2])
torch.Size([3, 10])