torch.tensor 和 torch.Tensor


最近在学习pytorch时,再写一个很简单的代码时一直提示报错,最后终于发现是因为我使用的torch.tensor的原因,换成torch.Tensor问题就解决了。

为了加深学习,这里将二者的区别给出:

  • 在PyTorch文档它被写torch.Tensor是一个别名torch.FloatTensor。分别使用二者时如下:
>>> torch.Tensor([1,2,3]).dtype
torch.float32
>>> torch.tensor([1, 2, 3]).dtype
Out[32]: torch.int64
>>> torch.Tensor([True, False]).dtype
torch.float32
>>> torch.tensor([True, False]).dtype
torch.uint8
  • 这种情况是由于torch.tensor自动推断类型,而torch.Tensor默认上全局返回torch.FloatTensor。如果想更改类型,建议使用torch.tensor,它也有类似的参数dtype。