train_classification


点击查看代码
#下面的代码中的一些超参数并没有按照原论文中来,不过总体结构差不多
"""
Author: Benny
Date: Nov 2019
"""
import os
import sys
import torch
import numpy as np

import datetime                                                        #获取当前日期和时间
import logging                                                         #创建日志类并使用
import provider                                                        #一个包含众多功能函数的程序文件
import importlib                                                       #动态导入模块实现代码(需要在程序的运行过程时才能决定导入某个文件中的模块时)
import shutil                                                          #用于操作文件夹或者文件
import argparse                                                        #用于命令项选项与参数解析的模块,直接在命令行中就可以向程序中传入参数并让程序运行

from pathlib import Path                                               #面向对象的文件系统路径(os能干的该模块也能做且命令更加简洁)
from tqdm import tqdm                                                  #一个快速,可扩展的Python进度条,在 Python 长循环中添加一个进度提示信息,只需要封装任意的迭代器tqdm(iterator)
from data_utils.ModelNetDataLoader import ModelNetDataLoader

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR                                                    #BASE_DIR,ROOT_DIR输出都是/home/yanhua/PycharmProjects/Pointnet_Pointnet2_pytorch-master
sys.path.append(os.path.join(ROOT_DIR, 'models'))                      #这句代码使得可以实现model = importlib.import_module(args.model)


def parse_args():
    '''argparse使用步骤:
    1、创建 ArgumentParser() 对象
    2、调用 add_argument() 方法添加参数
    3、使用 parse_args() 解析添加的参数

    default:  对于参数,default的值用于选项字符串没有出现在命令行中的时候
    type: 可以设置传入参数要求的类型
    choices: 可以设置填入的参数在 choices 指定的范围内
    help: 填写该参数背后的一些帮助信息
    '''
    '''PARAMETERS'''
    parser = argparse.ArgumentParser('training')                                                                   #description - 在参数帮助文档之前显示的文本(默认无,无则显示文件名)
    parser.add_argument('--use_cpu', action='store_true', default=True, help='use cpu mode')                       #action='store_true'-让此值(--use_cpu)直接默认设置为 bool 值
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40],  help='training on ModelNet10/40')
    parser.add_argument('--epoch', default=100, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')                               #利用权重衰减可以防止过拟合
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')                   #选择是否使用点的全部特征,还是只取xyz
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')            #保存/加载采样的点
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')  #选择的采样点的方式-采样点方式相同还是最远点采样
    return parser.parse_args()


def inplace_relu(m):                                                              #该函数的功能是节省内存,类似nn.ReLU(inplace=True)
    classname = m.__class__.__name__                                              #获取类名
    if classname.find('ReLU') != -1:                                              #如果类名当中有Relu的话
        m.inplace=True                                                            #inplace-选择是否进行覆盖运算


def test(model, loader, num_class=40):
    mean_correct = []
    class_acc = np.zeros((num_class, 3))
    classifier = model.eval()                                                     #model.eval()-保证BN用全部训练数据的均值和方差,对于Dropout,则是利用到所有网络连接

    for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):

        if not args.use_cpu:
            points, target = points.cuda(), target.cuda()

        points = points.transpose(2, 1)                                          #transpose相当于数学中的转置,即行与列相互调换位置,执行完后points的维度是B*D*N
        pred, _ = classifier(points)                                             #此时的'_'根据网络可知代表的是trans_feat,而pred的维度是[batchsize,num_category]
        pred_choice = pred.data.max(1)[1]                                        #(1)表按行(二维数据)求最大值,(0)则是列,[1]表返回最大值的索引,[0]表返回最大值的每个数
                                                                                 #target.cpu()-将数据移至CPU中,返回值是cpu上的Tensor
        for cat in np.unique(target.cpu()):                                      #np.unique-去除数组中的重复数字,并进行排序之后输出,若不是一维数组则会展开
            classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()#此时的pred_choice和target的shape都是[batchsize],classacc代表在每批中对cat类别预测正确的个数,这句代码很巧妙
            class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])          #统计每个类别的批次准确率之和,[0]的目的是为了取出一个整型,该整型代表cat类别在此批次中的数量
            class_acc[cat, 1] += 1                                                                 #统计出现cat类别的批次数量

        correct = pred_choice.eq(target.long().data).cpu().sum()                                   #同时计算每批次所有类别准确个数之和
        mean_correct.append(correct.item() / float(points.size()[0]))                              #获取每批次的总体准确率

    class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]                          #计算每个类别在一个epoch后的准确率,计算方式为:类别在每批次的准确率之和/批次数量
    class_acc = np.mean(class_acc[:, 2])                                         #求平均分类准确率(accuracy avg.class),计算方式为:所有类别在一个epoch中的准确率之和/num_category
    instance_acc = np.mean(mean_correct)                                         #求整体的准确率(accuracy overall)
                                                                                 #class_acc是在每个epoch中,分别计算每个类别的准确率,再平均求总体的准确率;而
                                                                                 #instance_acc是同时计算所有正确的个数,再一次性求总体准确率
    return instance_acc, class_acc


def main(args):
    def log_string(str):                                                         #可打印日志信息,原来下面代码中的log_string()函数是自己定义的,而不是log模块内含的
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''                                                        #environ是一个字符串所对应环境的映像对象
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))            #得到标准的时间格式并字符串化
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)                                                 #exist_ok:只有在目录不存在时创建目录,目录已存在时不会抛出异常。
    exp_dir = exp_dir.joinpath('classification')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG,详细信息见https://www.jb51.net/article/193299.html'''
    args = parse_args()
    logger = logging.getLogger("Model")                                                       #logging.getLogger(name)方法进行初始化logger这个日志对象
    logger.setLevel(logging.INFO)                                                             #logger.setLevel-设置日志等级
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')     #对输出消息的格式进行设置
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))                   #将log写入文件
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)                                                      #设置文件的log格式
    logger.addHandler(file_handler)                                                           #logger日志对象加载FileHandler对象
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    data_path = 'data/modelnet40_normal_resampled/'

    train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train')          #Dataset负责整理数据
    test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test')            #num_workers决定了有几个进程来处理data loading
    trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)#drop_last-丢不丢弃最后无法整除而剩余的样本
    testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10) ##Dataloader负责在整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用

    '''MODEL LOADING'''
    num_class = args.num_category
    model = importlib.import_module(args.model)                                           #上头的sys.path.append()协助实现 
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))                              #下面三个shutil是将用到的py文件复制到log中            
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_classification.py', str(exp_dir))

    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    criterion = model.get_loss()
    classifier.apply(inplace_relu)                                                        #ReLU与inplace原地操作,旧内存上直接更改数值,在使用原地操作前,我们要确定其是贯序的,该方法可以节省内存

    if not args.use_cpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()

    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,                                                                    #eps-该参数是非常小的数,其为了防止在实现中除以零(如 10E-8)
            weight_decay=args.decay_rate                                                  #weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
                                                                                          #这里将gamma从0.7改为论文中的0.5(is divided by 2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)       #torch.optim.lr_scheduler-根据epoch训练次数来调整学习率,gamma(float)-更新lr的乘法因子
    global_epoch = 0
    global_step = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        mean_correct = []
        classifier = classifier.train()                                                    #启用Batch Normalization和Dropout,保证BN层用每一批数据的均值和方差,随机取一部分网络连接来训练更新参数

        scheduler.step()
        for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): #smoothing-速度估计的指数移动平均平滑因子,enumerate中的0表示下标起始位置
            optimizer.zero_grad()
            points = points.data.numpy()
            points = provider.random_point_dropout(points)                                 #随机丢弃点,通过用第一个点代替实现
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])       #对每批数据的每个点的每个特征都进行了随机缩放,批之间缩放不一致,批内则一致
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])              #对每批数据的每个点的每个特征都进行了随机平移,批之间平移不一致,批内相同维度一致,不同维度不一致
            points = torch.Tensor(points)                                                  #此时的points的维度是batchsize*N*D
            points = points.transpose(2, 1)                                                #此时的points的维度是B*D*N

            if not args.use_cpu:
                points, target = points.cuda(), target.cuda()

            pred, trans_feat = classifier(points)                                          #此时的pred维度是batchsize*num_category
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]                                              #max(1)[1]-(1):按行求最大值,[1]:返回最大值的每个索引,[0]:返回最大值的每个数

            correct = pred_choice.eq(target.long().data).cpu().sum()                       #eq-比较tensor中对应数据是否相等(相等为True,否则为False),sum()-此时统计相等的个数
            mean_correct.append(correct.item() / float(points.size()[0]))                  #mean_correct-计算并收集每批数据训练完后的准确率
            loss.backward()
            optimizer.step()
            global_step += 1                                                               #global_step可统计训练完后总共运行了多少次batch处理,这手算一下就行啊。。。

        train_instance_acc = np.mean(mean_correct)                                         #计算每个epoch的准确率
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():                                                              #with语句适用于对资源进行访问的场合,确保不管使用过程中是否发生异常都会执行必要的“清理”操作,释放资源
            instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)   #classifier.eval()-此时BN层和dropout层不会在推断时有效
                                                                                           #with torch.no_grad-该模块下,所有计算得出的tensor的requires_grad都自动设置为False,节约了显存或者说内存
            if (instance_acc >= best_instance_acc):
                best_instance_acc = instance_acc
                best_epoch = epoch + 1

            if (class_acc >= best_class_acc):
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

            if (instance_acc >= best_instance_acc):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'instance_acc': instance_acc,
                    'class_acc': class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')


if __name__ == '__main__':
    args = parse_args()
    main(args)

相关