从零构建自己的远控•IOCP服务器构建(8)


#include "pch.h"
#include "framework.h"
#include 
#include 
#pragma comment(lib,"ws2_32.lib")

#include "CLock.h"
#include "CIOCPServer.h"
#include "../common/lz4/lz4.h"

//传送协议规则
//第一个字节 数据包是否加密
//int  4字节  数据包原大小
//int  4字节  数据包压缩后大小
//int  4字节  数据包最大缓冲区
#define HDR_SIZE    13

//先声明临界区对象 防止找不到对象
CRITICAL_SECTION CIOCPServer::m_cs;
/// 
/// 初始化类对象成员
/// 
CIOCPServer::CIOCPServer()
{
    WSADATA wsaData;
    WSAStartup(0x202, &wsaData);         //初始化套接字2.2
    //初始化临界区对象
    InitializeCriticalSection(&m_cs);

    m_hThread = NULL;
    m_socListen = NULL;

    m_bTimeToKill = false;
    m_bDisconnectAll = false;

    m_hEvent = NULL;
    m_hCompletionPort = NULL;

    m_bInit = false;
    m_nCurrentThreads = 0;
    m_nBusyThreads = 0;

    m_nSendKbps = 0;
    m_nRecvKbps = 0;

    m_nMaxConnections = 10000;
    m_nKeepLiveTime = 1000 * 60 * 3; // 三分钟探测一次
}

CIOCPServer::~CIOCPServer()
{
}


/// 
/// 初始化监听端口
/// 
/// 最大连接数
/// 端口
/// bool
bool CIOCPServer::Initialize(int nMaxConnections, int nPort)
{

    m_nMaxConnections = nMaxConnections;
    //创建一个重叠模型socket初始化套接字WSA_FLAG_OVERLAPPED
    m_socListen = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
    //初始化 套接字
    if (m_socListen == INVALID_SOCKET)
    {
        return false;
    }
    //完成端口初始化
    InitializeIOCP();


#pragma region 绑定端口&监听

    SOCKADDR_IN        saServer;
    //大小端转换 绑定端口
    saServer.sin_port = htons(nPort);

    //绑定空地址
    saServer.sin_family = AF_INET;
    saServer.sin_addr.s_addr = INADDR_ANY;

    // 绑定端口 套接字
    int  nRet = bind(m_socListen,
        (LPSOCKADDR)&saServer,
        sizeof(struct sockaddr));

    if (nRet == SOCKET_ERROR)
    {
        closesocket(m_socListen);
        return false;
    }

    // 监听
    nRet = listen(m_socListen, SOMAXCONN);
    if (nRet == SOCKET_ERROR)
    {
        closesocket(m_socListen);
        return false;
    }

#pragma endregion

#pragma region 事件模型线程池
    // 创建一个事件模型
    m_hEvent = WSACreateEvent();

    if (m_hEvent == WSA_INVALID_EVENT)
    {
        closesocket(m_socListen);
        return false;
    }

    //绑定当前事件和套接字  FD_ACCEPT 应用程序想接收与进入连接有关的通知
    nRet = WSAEventSelect(m_socListen,
        m_hEvent,
        FD_ACCEPT);

    if (nRet == SOCKET_ERROR)
    {
        closesocket(m_socListen);
        return false;
    }

    //开启监听线程 
    m_hThread = (HANDLE)_beginthreadex(NULL, 0, ListenThreadProc, (void*)this, 0, NULL);

#pragma endregion

    return false;
}

/// 
/// 客户端连接线程
/// 
/// this
/// 
unsigned WINAPI  CIOCPServer::ListenThreadProc(LPVOID lParam)
{
    //强转当前处理iocp类
    CIOCPServer* pThis = reinterpret_cast(lParam);

    WSANETWORKEVENTS events;
    while (1)//循环等待接收连接处理信号
    {
        DWORD dwRet;
        //是否有客户端连接 1 数量 2 对象事件
        dwRet = WSAWaitForMultipleEvents(1, &pThis->m_hEvent,
            FALSE,
            1000,//超时时间
            FALSE);
        //事件对象超时 重新等待
        if (dwRet == WSA_WAIT_TIMEOUT)
            continue;

        int nRet = WSAEnumNetworkEvents(pThis->m_socListen,
            pThis->m_hEvent,
            &events);
        if (nRet == SOCKET_ERROR)
        {
            break;
        }
        //是否有客户端连接事件
        if (events.lNetworkEvents & FD_ACCEPT)
        {
            pThis->OnAccept();//连接处理代码
        }
    }
    return 0;
}

/// 
/// 连接处理
/// 
void CIOCPServer::OnAccept()
{

    SOCKADDR_IN    SockAddr;
    SOCKET        clientSocket;

    int            nRet;
    int            nLen;
    //退出
    if (m_bTimeToKill || m_bDisconnectAll)
        return;

    //连接  在事件处理过后 代表有人连接时 再去缓冲区拿数据
    nLen = sizeof(SOCKADDR_IN);
    clientSocket = accept(m_socListen,
        (LPSOCKADDR)&SockAddr,
        &nLen);

    if (clientSocket == SOCKET_ERROR)
    {
        nRet = WSAGetLastError();
        if (nRet != WSAEWOULDBLOCK)
        {
            return;
        }
    }

    //初始化
    ClientContext* pContext = new ClientContext;
    if (pContext == NULL)
        return;

    pContext->m_Socket = clientSocket;

    pContext->m_wsaInBuffer.buf = (char*)pContext->m_byInBuffer;
    pContext->m_wsaInBuffer.len = sizeof(pContext->m_byInBuffer);
    //完成端口 传递pcontext  初始化
    HANDLE h = CreateIoCompletionPort((HANDLE)clientSocket, m_hCompletionPort, (DWORD)pContext, 0);
    if (!(m_hCompletionPort == h))
    {
        delete pContext;
        pContext = NULL;

        closesocket(clientSocket);
        closesocket(m_socListen);
        return;
    }

    const char chOpt = 1;
    // Set KeepAlive 开启保活机制 给每个创建的客户端开启保存获取包
    setsockopt(pContext->m_Socket, SOL_SOCKET, SO_KEEPALIVE, &chOpt, sizeof(char));

    // 设置超时详细信息 保活探测
    tcp_keepalive    klive;
    klive.onoff = 1; // 启用保活
    klive.keepalivetime = m_nKeepLiveTime;
    klive.keepaliveinterval = 1000 * 10; // 重试间隔为10秒 
    unsigned long  lgpt=NULL;
    //设置套接字模式
    int t=WSAIoctl
    (
        pContext->m_Socket,
        SIO_KEEPALIVE_VALS,
        &klive,
        sizeof(tcp_keepalive),
        NULL,
        0,
        &lgpt,
        0,
        NULL
    );

    CLock cs(m_cs);//加锁
    //当前客户端连接加入队列
    m_listContexts.AddTail(pContext);

    //完成端口 向端口发送数据告知有客户端成功连接 功能占时没写
    //BOOL bSuccess = PostQueuedCompletionStatus(m_hCompletionPort, 0, (DWORD)pContext, NULL);
    ////客户端是否断开连接
    //if ((!bSuccess && GetLastError() != ERROR_IO_PENDING))
    //{
    //    //清除客户端
    //    RemoveStaleClient(pContext, TRUE);
    //    return;
    //}

    PostRecv(pContext);
}

/// 
/// 接收数据
/// 
/// 
void CIOCPServer::PostRecv(ClientContext* pContext)
{
    // issue a read request 
    OVERLAPPEDPLUS* pOverlap = new OVERLAPPEDPLUS(IORead);
    ULONG            ulFlags = MSG_PARTIAL;
    DWORD            dwNumberOfBytesRecvd;
    //代表异步收好了 那边再来展示
    UINT nRetVal = WSARecv(pContext->m_Socket,
        &pContext->m_wsaInBuffer,
        1,
        &dwNumberOfBytesRecvd,
        &ulFlags,
        &pOverlap->m_ol,
        NULL);
    //WSAGetLastError() == WSAECONNRESET 应该是或者
    if (nRetVal == SOCKET_ERROR && WSAGetLastError() != WSA_IO_PENDING)
    {
        RemoveStaleClient(pContext, FALSE);
    }
}

/// 
/// 创建线程池绑定端口
/// 
/// 
/// 
bool CIOCPServer::InitializeIOCP()
{

    SOCKET s;
    DWORD i;
    UINT  nThreadID;
    SYSTEM_INFO systemInfo;

    //完成所有端口绑定
    m_hCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    if (m_hCompletionPort == NULL)
    {
        return false;
    }

    //获取计算机信息
    GetSystemInfo(&systemInfo);

    m_nThreadPoolMin = systemInfo.dwNumberOfProcessors * 2;

    //开2*线程
    UINT nWorkerCnt = systemInfo.dwNumberOfProcessors * 2;

    HANDLE hWorker;
    m_nWorkerCnt = 0;
    //创建线程池
    for (i = 0; i < nWorkerCnt; i++)
    {
        hWorker = (HANDLE)_beginthreadex(NULL,
            0,
            ThreadPoolFunc, //轮询收数据     
            (void*)this,
            0,
            &nThreadID);


        if (hWorker == NULL)
        {
            CloseHandle(m_hCompletionPort);
            return false;
        }

        m_nWorkerCnt++;

        CloseHandle(hWorker);
    }

    return true;
}

/// 
/// 线程池工作线程
/// 
/// 
/// 
unsigned CIOCPServer::ThreadPoolFunc(LPVOID thisContext)
{

    ULONG ulFlags = MSG_PARTIAL;
    CIOCPServer* pThis = reinterpret_cast(thisContext);
    ASSERT(pThis);
    //完成端口
    HANDLE hCompletionPort = pThis->m_hCompletionPort;

    DWORD dwIoSize;
    LPOVERLAPPED lpOverlapped;
    ClientContext* lpClientContext;
    OVERLAPPEDPLUS* pOverlapPlus;
    bool            bError;
    bool            bEnterRead;

    InterlockedIncrement(&pThis->m_nCurrentThreads);//i++
    InterlockedIncrement(&pThis->m_nBusyThreads);

    for (BOOL bStayInPool = TRUE; bStayInPool && pThis->m_bTimeToKill == false; )
    {
        pOverlapPlus = NULL;
        lpClientContext = NULL;
        bError = false;
        bEnterRead = false;

        InterlockedDecrement(&pThis->m_nBusyThreads);//i--


        //检查完成端口状态
        BOOL bIORet = GetQueuedCompletionStatus(
            hCompletionPort,
            &dwIoSize,
            (LPDWORD)&lpClientContext,
            &lpOverlapped, INFINITE);

        DWORD dwIOError = GetLastError();
        pOverlapPlus = CONTAINING_RECORD(lpOverlapped, OVERLAPPEDPLUS, m_ol);

        //忙碌++
        int nBusyThreads = InterlockedIncrement(&pThis->m_nBusyThreads);
        //删除无用客户端
        if (!bIORet && dwIOError != WAIT_TIMEOUT)
        {
            if (lpClientContext && pThis->m_bTimeToKill == false)
            {
                pThis->RemoveStaleClient(lpClientContext, FALSE);
            }
            continue;

            // anyway, this was an error and we should exit
            bError = true;
        }

        if (!bError)
        {
            if (bIORet && NULL != pOverlapPlus && NULL != lpClientContext)
            {
                try
                {
                    if (pOverlapPlus->m_ioType == IORead)
                    {//处理客户端发来消息
                        pThis->OnClientReading(lpClientContext, dwIoSize);
                    }
                    if (pOverlapPlus->m_ioType == IOWrite)
                    {//向客户端发送消息
                        pThis->OnClientWriting(lpClientContext, dwIoSize);
                    }
                }
                catch (...) {}
            }
        }

        if (pOverlapPlus)
            delete pOverlapPlus; // from previous call
    }

    InterlockedDecrement(&pThis->m_nWorkerCnt);//减减

    InterlockedDecrement(&pThis->m_nCurrentThreads);//减减
    InterlockedDecrement(&pThis->m_nBusyThreads);//减减
    return 0;
}

/// 
/// 处理客户端发来消息
/// 
/// 发送来的数据结构
/// 数据包大小
/// 
bool CIOCPServer::OnClientReading(ClientContext* pContext, DWORD dwIoSize)
{
    //传送协议规则
    //第一个字节 数据包是否加密
    //int  4字节  数据包原大小
    //int  4字节  数据包压缩后大小

    CLock cs(CIOCPServer::m_cs);
    try
    {
        //如果为0就删除此客户端
        if (dwIoSize == 0)
        {
            RemoveStaleClient(pContext, FALSE);
            return false;
        }

        //接收数据  
        pContext->m_CompressionBuffer.Write(pContext->m_byInBuffer, dwIoSize);

        //数据包大小是否合法
        while (pContext->m_CompressionBuffer.m_csSize > HDR_SIZE)
        {
            int PacketLen = 0;//这里是原包大小
            int UnZipLen = 0;//这里是压缩后包大小
            int MaxPacketLen = 0;//这里是原包缓冲区大小
            CopyMemory(&PacketLen, pContext->m_CompressionBuffer.m_pBase + 1, sizeof(int));
            CopyMemory(&UnZipLen, pContext->m_CompressionBuffer.m_pBase + 5, sizeof(int));
            CopyMemory(&MaxPacketLen, pContext->m_CompressionBuffer.m_pBase + 9, sizeof(int));
            if (!PacketLen ||!UnZipLen || !MaxPacketLen)
            {
                return false;
            }
            //去除协议头
            char temp[HDR_SIZE] = {};
            pContext->m_CompressionBuffer.Read((PBYTE)temp, HDR_SIZE);

            PBYTE pDst = new BYTE[MaxPacketLen];

            //2 返回解压数据 3原大小 4压缩大小
            int    nRet = LZ4_decompress_safe((char*)pContext->m_CompressionBuffer.m_pBase, (char*)pDst, UnZipLen, PacketLen);
            if (nRet < 1)
            {
                return false;
            }
            pContext->m_DeCompressionBuffer.ClearBuffer();
            //赋值给解压缩后的大小
            pContext->m_DeCompressionBuffer.Write(pDst, PacketLen);
            delete[] pDst;
        }
    }
    catch (...)
    {
        pContext->m_CompressionBuffer.ClearBuffer();
        // 要求重发,就发送0, 内核自动添加数包标志
        Send(pContext, NULL, 0);
    }
    //接收包 每次接受完就要接收下一个包
    PostRecv(pContext);

    return true;
}

//向客户端发送消息
bool CIOCPServer::OnClientWriting(ClientContext* pContext, DWORD dwIoSize)
{
    try
    {
        ULONG ulFlags = MSG_PARTIAL;

        pContext->m_WriteBuffer.Delete(dwIoSize);
        if (pContext->m_WriteBuffer.m_csSize == 0)
        {
            pContext->m_WriteBuffer.ClearBuffer();
            SetEvent(pContext->m_hWriteComplete);
            return true;            
        }
        else
        {
            OVERLAPPEDPLUS* pOverlap = new OVERLAPPEDPLUS(IOWrite);

            pContext->m_wsaOutBuffer.buf = (char*)pContext->m_WriteBuffer.m_pBase;
            pContext->m_wsaOutBuffer.len = pContext->m_WriteBuffer.m_csSize;
            //发送客户端数据
            int nRetVal = WSASend(pContext->m_Socket,
                &pContext->m_wsaOutBuffer,
                1,
                &pContext->m_wsaOutBuffer.len,
                ulFlags,
                &pOverlap->m_ol,
                NULL);

            //失败就是删除客户端
            if (nRetVal == SOCKET_ERROR && WSAGetLastError() != WSA_IO_PENDING)
            {
                RemoveStaleClient(pContext, FALSE);
            }

        }
    }
    catch (...) {}
    return false;
}

//发送消息
void CIOCPServer::Send(ClientContext* pContext, LPBYTE lpData, UINT nSize)
{
    if (pContext == NULL)
        return;
    try
    {
        if (nSize > 0)
        {
            //压缩数据
            const size_t src_size = strlen((char*)lpData) + 1;
            const size_t max_dst_size = LZ4_compressBound(src_size);
            LPBYTE    pDest = new BYTE[max_dst_size];//分配压缩数据的空间
            int    nRet = LZ4_compress_default((char*)lpData, (char*)pDest, src_size, max_dst_size);
            if (nRet <1)
            {
                delete[] pDest;
                return;
            }

            LONG nBufLen = src_size + HDR_SIZE;//数据中加入数据头标识大小
            //是否加密
            pContext->m_WriteBuffer.Write(0, sizeof(char));
            //原数据包大小
            pContext->m_WriteBuffer.Write((PBYTE)&src_size, sizeof(src_size));
            //压缩后大小
            pContext->m_WriteBuffer.Write((PBYTE)&nRet, sizeof(nRet));
            //缓冲区大小
            pContext->m_WriteBuffer.Write((PBYTE)&max_dst_size, sizeof(max_dst_size));
            //数据包
            pContext->m_WriteBuffer.Write(pDest, sizeof(nRet));
            delete[] pDest;

            // 发送完后,再备份数据, 因为有可能是m_ResendWriteBuffer本身在发送,所以不直接写入
            LPBYTE lpResendWriteBuffer = new BYTE[nSize];
            CopyMemory(lpResendWriteBuffer, lpData, nSize);
            pContext->m_ResendWriteBuffer.ClearBuffer();
            pContext->m_ResendWriteBuffer.Write(lpResendWriteBuffer, nSize);    // 备份发送的数据
            delete[] lpResendWriteBuffer;
        }
        else // 要求重发
        {
            // 备份发送的数据    
        }
        WaitForSingleObject(pContext->m_hWriteComplete, INFINITE);

        OVERLAPPEDPLUS* pOverlap = new OVERLAPPEDPLUS(IOWrite);
        //让完成端口发数据
        PostQueuedCompletionStatus(m_hCompletionPort, 0, (DWORD)pContext, &pOverlap->m_ol);

        pContext->m_nMsgOut++;
    }
    catch (...) {}
}



// 删除客户端
void CIOCPServer::RemoveStaleClient(ClientContext* pContext, BOOL bGraceful)
{
    CLock cs(m_cs);
    LINGER lingerStruct;
    if (!bGraceful)
    {

        lingerStruct.l_onoff = 1;
        lingerStruct.l_linger = 0;
        //把数据发送完
        setsockopt(pContext->m_Socket, SOL_SOCKET, SO_LINGER,
            (char*)&lingerStruct, sizeof(lingerStruct));
    }



    //
    //防止当前客户端已经不存在 列表中
    if (m_listContexts.Find(pContext))
    {

        //
        // 发完消息删除socket
        CancelIo((HANDLE)pContext->m_Socket);

        closesocket(pContext->m_Socket);
        pContext->m_Socket = INVALID_SOCKET;
        //检查是否有没有返回的重叠I/O请求
        while (!HasOverlappedIoCompleted((LPOVERLAPPED)pContext))
            Sleep(0);

        MoveToFreePool(pContext);

    }
}
void CIOCPServer::MoveToFreePool(ClientContext* pContext)
{
    //清楚列表这个接连结构体
    CLock cs(m_cs);
    POSITION pos = m_listContexts.Find(pContext);
    if (pos)
    {
        pContext->m_CompressionBuffer.ClearBuffer();
        pContext->m_WriteBuffer.ClearBuffer();
        pContext->m_DeCompressionBuffer.ClearBuffer();
        pContext->m_ResendWriteBuffer.ClearBuffer();
        m_listFreePool.AddTail(pContext);
        m_listContexts.RemoveAt(pos);
    }
}

相关