pytorch网络转libtorch常见问题
目录
一、RuntimeError: All inputs of range must be ints, found Tensor in argument 0:
二、RuntimeError: Sliced expression not yet supported for subscripted assignment. File a bug if you want this:
三、RuntimeError: Tried to access nonexistent attribute or method 'len' of type 'torch.torch.nn.modules.container.ModuleList'. Did you forget to initialize an attribute in init()?
四、RuntimeError: Expected integer literal for index
五、RuntimeError: Arguments for call are not valid. The following variants are available:
一、All inputs of range must be ints, found Tensor in argument 0:
问题
参数类型不正确,函数的默认参数是tensor
解决措施
函数传入参数不是tensor需要注明类型
我的问题是传入参数npoint是一个int类型,没有注明会报错,更改如下:
由
def test(npoint):
...
更改为
def test(npoint: int):
...
二、Sliced expression not yet supported for subscripted assignment. File a bug if you want this:
问题
不支持赋值给切片表达式
解决措施
根据自己需求,进行修改,可利用循环替代
我将view_shape[1:] = [1] * (len(view_shape) - 1)更改为
for i in range(1, len(view_shape)):
view_shape[i] = 1
三、Tried to access nonexistent attribute or method 'len' of type 'torch.torch.nn.modules.container.ModuleList'. Did you forget to initialize an attribute in init()?
问题
forward函数中好像不支持len(nn.ModuleList())和下标访问
解决措施
如果是一个ModuleList()可以用enumerate函数,多个同维度的可以用zip函数
我这里有两个ModuleList(),所以采用zip函数,更改如下:
由
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
更改为
for conv, bn in zip(self.mlp_convs, self.mlp_bns):
new_points = F.relu(bn(conv(new_points)))
ref: https://github.com/pytorch/pytorch/issues/16123
四、Expected integer literal for index
问题和解决方法类似第三个
五、Arguments for call are not valid. The following variants are available
Expected a value of type 'List[Tensor]' for argument 'indices' but instead found type 'List[Optional[Tensor]]'
问题
赋值类型不对,需求是tensor,但给的是int
解决措施
- 方法1
将int类型的数N用torch.tensor(N)代替
由
mask = sqrdists > radius ** 2
group_idx[mask] = N
变为
mask = sqrdists > radius ** 2
group_idx[mask] = torch.tensor(N)
- 方法2 (速度较慢)
用for循环替代`
由
mask = sqrdists > radius ** 2
group_idx[mask] = N
变为
B, rows, cols = sqrdists.shape
ref_redius = radius ** 2
for b in range(B):
for r in range(rows):
print("r: ", r)
for c in range(cols):
if sqrdists[b][r][c] > ref_redius:
group_idx[b][r][c] = N