OpenCV使用GMM实现图像分割


一、概述

  案例:使用GMM机器学习算法实现图像分割

  相关API介绍:

Ptr emModel = EM::create();//创建EM实例
emModel->setClustersNumber(numCluster);//设置分类个数
emModel->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);//设置协方差矩阵模型
emModel->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));//设置停止条件
emModel->trainEM(points,noArray(),labels,noArray());//训练,其中labels中存放的是训练好后的分类标签,同一个分类标签相同

  实现分割算法的步骤:

    1.载入图像

    2.准备训练样本数据:将载入的图像转为CV_32FC或者CV_64FC1类型的

    3.创建EM实例并训练样本是数据,训练好后输出分类标签

    4.创建一个CV_8UC3通道的类型的Mat用于存放,标记后的像素点

    5.根据labels将分类后的像素存入第4步中的mat中

    6.输出图像

二、代码演示

  Mat src = imread(filePath);
    if(src.empty()){
        qDebug()<<"载入图像为空";
        return;
    }

    imshow("src",src);

    //获取原始图像的宽、高、通道数
    int width = src.cols;
    int height = src.rows;
    int dims = src.channels();

    //将原始图像数据转为double类型
    int numSamples = width*height;//总共的像素点个数
    Mat points(numSamples,dims,CV_64FC1);//存储转换后浮点数据的容器
    Mat labels;//分类标签
    //填充样本像素数据
    int index = 0;
    for(int row = 0;row){
        for(int col = 0;col){
            index = width*row+col;
            Vec3b rgb = src.at(row,col);
            points.at<double>(index,0)= static_cast<int>(rgb[0]);
            points.at<double>(index,1)= static_cast<int>(rgb[1]);
            points.at<double>(index,2)= static_cast<int>(rgb[2]);
        }
    }

    //
    int numCluster = 3;//3分类
    Scalar clolors[] = {
        Scalar(255, 0, 0),
        Scalar(0, 255, 0),
        Scalar(0, 0, 255),
        Scalar(255, 255, 0)
    };
    //使用emm分类模型进行分类
    Ptr emModel = EM::create();
    emModel->setClustersNumber(numCluster);//设置分类个数
    emModel->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);//设置协方差矩阵模型
    emModel->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));//设置停止条件
    emModel->trainEM(points,noArray(),labels,noArray());//训练

    //根据标签分类
    Mat result = Mat::zeros(src.size(),CV_8UC3);

    for(int row=0;row){
        for(int col=0;col){
            index = row*width+col;
            int label = labels.at<int>(index,0);
            Scalar color = clolors[label];
            result.at(row,col)[0] = color[0];
            result.at(row,col)[1] = color[1];
            result.at(row,col)[2] = color[2];
        }
    }

    imshow("result",result);

三、示例图片