pytorch权重转wts格式,用于tensorrt权重加载
若使用tensorrt加载wts格式,需将模型训练的pt、pth、ckpt等格式权重转换为wts。
但因简单,我只记录此代码,供读者使用。
def checkpint2wts(model, wts_file):
'''
model:模型,需要权重
wts_file:保存wts权重路径,如result.wts
'''
import struct
model_state_dict = model.state_dict() # 此处需要根据模型情况获得state_dict
with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model_state_dict().keys())))
for k, v in model_state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f', float(vv)).hex())
f.write('\n')