ModelNetDataLoader及部分数据示例


点击查看代码
'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset

import argparse
def parse_args():
    '''argparse使用步骤:
    1、创建 ArgumentParser() 对象
    2、调用 add_argument() 方法添加参数
    3、使用 parse_args() 解析添加的参数

    default:  对于参数,default的值用于选项字符串没有出现在命令行中的时候
    type: 可以设置传入参数要求的类型
    choices: 可以设置填入的参数在 choices 指定的范围内
    help: 填写该参数背后的一些帮助信息
    '''
    '''PARAMETERS'''
    parser = argparse.ArgumentParser('training')                                                                   #description - 在参数帮助文档之前显示的文本(默认无,无则显示文件名)
    parser.add_argument('--use_cpu', action='store_true', default=True, help='use cpu mode')                       #action='store_true'-让此值(--use_cpu)直接默认设置为 bool 值
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40],  help='training on ModelNet10/40')
    parser.add_argument('--epoch', default=2, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
    return parser.parse_args()
args=parse_args()


warnings.filterwarnings('ignore')


def pc_normalize(pc):                               #point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
    centroid = np.mean(pc, axis=0)                  #np.mean(,axis)-axis=0压缩行,对各列求均值;=1压缩列,对各行求均值
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))      #np.max-返回数组中最大的数,也可指定axis
    pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """
    Input:
        xyz: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:,:3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)              #随机确定一点
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)    #np.sum(,axis)-axis默认与缺失则将数组元素全加,为0压缩行将每一列相加压缩为一行,此时为-1将每一行相加压缩为一列
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)          #np.argmax(array,axis)-返回一个numpy数组中最大值的索引值
    point = point[centroids.astype(np.int32)]
    return point


class ModelNetDataLoader(Dataset):                  #torch.utils.data.Dataset-自定义数据读取的方法
    def __init__(self, root, args, split='train'):
        self.root = root
        self.npoints = args.num_point
        self.process_data = args.process_data
        self.uniform = args.use_uniform_sample
        self.use_normals = args.use_normals
        self.num_category = args.num_category

        if self.num_category == 10:
            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')                #self.catfile-包含了所有标签的文件
        else:
            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]                               #rstrip() 删除 string 字符串末尾的指定字符,默认为空白符,包括空格、换行符、回车符、制表符
        self.classes = dict(zip(self.cat, range(len(self.cat))))                                #dict+zip-快速创建字典,此时将字符串形式的标签转化为数字标签,此时生成不同模型对应的标签

        shape_ids = {}                                                                          #shape_ids-此时要是一个包含训练和测试文件名的字典
        if self.num_category == 10:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
        else:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

        assert (split == 'train' or split == 'test')                                            #assert-断言,作用是如果它的条件返回错误,则终止程序执行
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]                  #返回一个包含所有文件名(即模型名,去除了序号)的列表,‘_’.join-因为模型中有night_stand
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]                                       #返回了单个模型文件名(该文件包含了一个模型的完整点数据)组成的列表,列表组成[(模型名,单个模型文件的路径),...]
        print('The size of %s data is %d' % (split, len(self.datapath)))


        if self.uniform:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints)) #设置fps(最远点采样)的数据的文件保存路径,.dat文件-数据文件,程序使用的数据
        else:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))     #设置UniformSampling(均匀采样)的数据文件保存路径

        if self.process_data:                                                                   #process_data-加工数据,收集采样点
            if not os.path.exists(self.save_path):                                              #os.path.exists()—判断括号里的文件是否存在,可以是文件路径
                print('Processing data %s (only running in the first time)...' % self.save_path)
                self.list_of_points = [None] * len(self.datapath)
                self.list_of_labels = [None] * len(self.datapath)

                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
                    fn = self.datapath[index]                                                   #此时的fn是一个元组,(文件名,文件路径)
                    cls = self.classes[self.datapath[index][0]]                                 #cls-获取单个模型的标签
                    cls = np.array([cls]).astype(np.int32)                                      #将int型cls转换成只含有单个元素的数组形式的cls
                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)             #delimiter-分隔符,默认是空格,loadtxt的功能是读入数据文件且要求每一行数据的格式相同,读1和2维数组的文本文件(写入是savetxt)
                                                                                                #此时的point_set是包含列表(每个子列表代表一个点,有6个元素)的列表,从文件数据中可知维度是10000*6
                    if self.uniform:                                                            #选取采样点,采样完成后的point_set的维度是self.npoints*6
                        point_set = farthest_point_sample(point_set, self.npoints)              #使用fps采样
                    else:                                                                       #使用UniformSampling采样
                        point_set = point_set[0:self.npoints, :]                                #这句代码这么暴力,怎么像是均匀采样?

                    self.list_of_points[index] = point_set                                      #收集所有准备训练的模型数据,此时的list_of_points的维度是len(self.datapath)*self.npoints*6
                    self.list_of_labels[index] = cls                                            #收集所有参加训练的模型标签,此时的list_of_labels的维度是len(self.datapath)*1

                with open(self.save_path, 'wb') as f:
                    pickle.dump([self.list_of_points, self.list_of_labels], f)                  #pickle.dump(obj, file)-将对象obj保存到文件file中去,若无file则自动创建
            else:
                print('Load processed data from %s...' % self.save_path)                        #Load processed data-加载已经加工后的数据
                with open(self.save_path, 'rb') as f:
                    self.list_of_points, self.list_of_labels = pickle.load(f)                   #pickle.load(f)-从file中读取一个字符串,并将它重构为原来的python对象

    def __len__(self):                                                                          #一个类表现得像一个list,要获取有多少个元素,就得用 len() 函数
        return len(self.datapath)                                                               #只要正确实现了__len__()方法,就可以用len()函数返回实例的长度

    def _get_item(self, index):
        if self.process_data:
            point_set, label = self.list_of_points[index], self.list_of_labels[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            label = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                point_set = point_set[0:self.npoints, :]
                
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]

        return point_set, label[0]

    def __getitem__(self, index):                           #该方法定义了使用索引值来查找元素的方法
        return self._get_item(index)


if __name__ == '__main__':
    import torch

    data = ModelNetDataLoader('/home/yanhua/PycharmProjects/Pointnet_Pointnet2_pytorch-master/data/modelnet40_normal_resampled/',args,split='train')   #要单独运行该文件,最好加绝对路径
    DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)       #torch.utils.data.DataLoader-封装了Data对象,实现单(多)进程迭代器输出数据集
    for point, label in DataLoader:
        print(point.shape)
        print(label)

之前以为ModelNet40里单个类别的数据表示的是同一个物体,如飞机类,最初我以为是只有一种飞机,其实有多种,如下图所示

相关