MXNet 输出中间层特征图
如果使用 Gluon 动态图的方式进行训练，直接把对应的中间变量 return 出来即可。这里介绍的是在 MXNet 中加载已训练好的模型（Symbol + 参数）并输出中间层特征图的方法。
import random
from collections import namedtuple

import numpy as np

# NOTE(review): `mx` (MXNet) and `lk` (the project's ITK/MetaImage I/O helper)
# are used below but are not imported anywhere in this file; they must be
# brought into scope (e.g. `import mxnet as mx`) before running.


def create_feature_map(prob, mhd_dict, feature_num, random_select=True):
    """Save up to ``feature_num`` channels of ``prob`` as MetaImage files.

    Each selected channel ``prob[c]`` is min-max rescaled to [0, 255], cast to
    ``np.short`` and written to ``feature_map<i>.raw`` in the current working
    directory, together with a matching ``feature_map<i>.mhd`` header. ``<i>``
    is the output index (0, 1, ...), not the channel number ``c``.

    Args:
        prob: Channel-first array indexed along axis 0 -- an MXNet ``NDArray``
            (each selected channel is converted with ``asnumpy()``) or
            anything ``np.asarray`` accepts.
        mhd_dict: MetaImage header dict used as a template for every output
            header. It is copied, so the caller's dict is left unchanged.
        feature_num: Maximum number of channels to save; values larger than
            the channel count are clamped, values <= 0 save nothing.
        random_select: If True, channels are picked in random order;
            otherwise the first ``feature_num`` channels are saved.
    """
    channels = list(range(prob.shape[0]))
    if random_select:
        random.shuffle(channels)

    # zip() stops at the shorter input, which clamps feature_num to the
    # number of available channels (and yields nothing for feature_num <= 0).
    for i, channel in zip(range(feature_num), channels):
        feature_map = prob[channel]
        if hasattr(feature_map, 'asnumpy'):
            feature_map = feature_map.asnumpy()
        else:
            feature_map = np.asarray(feature_map)

        _min, _max = feature_map.min(), feature_map.max()
        if _max > _min:
            feature_map = (feature_map - _min) / (_max - _min) * 255
        else:
            # A constant channel (e.g. a dead, all-zero ReLU map) would divide
            # by zero and turn into NaNs; write it as all zeros instead.
            feature_map = np.zeros_like(feature_map)

        raw_name = 'feature_map%d.raw' % i
        with open(raw_name, 'wb') as f:
            f.write(feature_map.astype(np.short).tobytes())

        # Copy so the caller's template header is not mutated between calls.
        header = mhd_dict.copy()
        # The voxels are written as np.short (16-bit signed), so the header
        # must say so even if the template image used another element type.
        header['ElementType'] = 'MET_SHORT'
        header['ElementDataFile'] = raw_name
        # MetaImage DimSize lists the fastest-varying axis first, i.e. the
        # reverse of the C-order numpy shape.
        header['DimSize'] = list(reversed(feature_map.shape))
        # NOTE(review): spacing/offset fields are inherited from the input
        # image; for down-sampled layers they may no longer match -- confirm.
        lk.write_dict('feature_map%d.mhd' % i, header)


def main():
    """Run one volume through a trained checkpoint and dump feature maps."""
    model_path = r'/home/PycharmProjects/densenet'
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_path, 300)

    # get_internals() exposes every intermediate output of the graph; the
    # printed names are the valid keys for indexing it below.
    internals = sym.get_internals()
    print(internals.list_outputs())
    conv1 = internals['conv21_fwd_output']
    relu1 = internals['relu32_fwd_output']
    # Each symbol in the group becomes one entry of mod.get_outputs(), in order.
    group = mx.symbol.Group([conv1, relu1])

    # The tapped outputs are intermediate layers, so the grouped graph
    # presumably has no label input; label_names=None stops Module from
    # looking for its default 'softmax_label'.
    mod = mx.mod.Module(symbol=group, context=mx.gpu(), label_names=None)
    mod.bind(for_training=False, data_shapes=[('data', (1, 1, 128, 128, 128))])
    mod.set_params(arg_params, aux_params)

    Batch = namedtuple('Batch', ['data'])
    image, _, _, mhd_dict = lk.load_itk(r'/home/image0.mhd', mhd_dict=True)
    # mx.nd.array accepts both numpy arrays and NDArrays (numpy input becomes
    # float32); then prepend the batch and channel axes.
    # NOTE(review): assumes load_itk returns a (128, 128, 128) volume matching
    # the bound data shape -- confirm.
    image = mx.nd.array(image)
    image = image.reshape((1, 1) + image.shape)
    mod.forward(Batch([image]))
    outputs = mod.get_outputs()
    # One NDArray per grouped symbol: outputs[0] is conv21, expected shape
    # (1, channels, ...); outputs[1] is relu32 (not saved here).
    print(outputs[0].shape)
    # Medical images: save each selected channel as a .raw/.mhd pair.
    create_feature_map(outputs[0][0], mhd_dict, 10)
    print('done')


if __name__ == "__main__":
    main()