使用 matplotlib 的 subplots 绘制多张图片


def draw_images_grid_with_labels(data, nrows, figsize=(12, 12), **subplots_adjust_param):
    """Show every item of ``data`` as an image on an ``nrows``-row grid.

    Images are placed row by row. The number of columns is the smallest one
    that fits all items, so a count that is not a multiple of ``nrows`` no
    longer loses its trailing images; leftover cells are left blank.

    Args:
        data: Sized, indexable collection. For each index ``i`` the image is
            ``data[i][0][0].permute(1, 2, 0)`` and the title is ``data[i][1]``.
            NOTE(review): this looks like (image_batch, label) pairs holding
            channel-first torch tensors -- confirm against the caller.
        nrows: Number of rows in the grid.
        figsize: Figure size forwarded to ``plt.subplots``.
        **subplots_adjust_param: Keyword arguments forwarded unchanged to
            ``plt.subplots_adjust`` (``left``, ``right``, ``wspace``, ...).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    count = len(data)
    if count == 0:
        # Nothing to draw; a zero-column grid would make plt.subplots raise.
        return

    import matplotlib.pyplot as plt

    # Ceiling division keeps a trailing, partially filled column instead of
    # silently dropping the images that do not fill a whole column.
    ncols = -(-count // nrows)

    # squeeze=False always yields a 2-D array of Axes, so flatten() also
    # works for single-row, single-column and 1x1 grids.
    _, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    for idx, axe in enumerate(axes.flatten()):
        if idx >= count:
            # Spare cell of the grid: hide its empty frame and ticks.
            axe.axis("off")
            continue
        img = data[idx][0][0].permute(1, 2, 0)
        label = data[idx][1]
        axe.imshow(img)
        axe.set_title(label)

    # Tip: calling plt.subplot_tool() here instead helps to find good
    # spacing values interactively.
    plt.subplots_adjust(**subplots_adjust_param)
    plt.show()
# Spacing settings handed to plt.subplots_adjust through the helper above.
# Positions are fractions of the figure size; tune them for your own figure.
params = dict(
    left=0.125,   # left edge of the subplot area
    right=0.9,    # right edge of the subplot area
    bottom=0,     # bottom edge of the subplot area
    top=0.3,      # top edge of the subplot area
    wspace=0.3,   # horizontal gap between subplots,
                  # as a fraction of the average axis width
    hspace=0.3,   # vertical gap between subplots,
                  # as a fraction of the average axis height
)
draw_images_grid_with_labels(res, nrows=2, **params)

需要根据实际需求调整 params，示例如下：