Rust:axum学习笔记(7) websocket


接上一篇继续,今天来学习如何用 axum 实现 WebSocket,代码如下:

Cargo.toml添加依赖项

[package]
# The package name is also the crate name, i.e. the tracing target ("websocket")
# that a RUST_LOG filter must name to show this crate's own log lines.
name = "websocket"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# "ws" enables axum::extract::ws (WebSocket, WebSocketUpgrade, Message);
# "headers" enables the TypedHeader extractor used in main.rs.
axum =  {version = "0.4.3", features = ["headers","ws"] }
# "full" includes the runtime and macros needed by #[tokio::main].
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
# "env-filter" lets tracing_subscriber::fmt::init() honour the RUST_LOG variable.
tracing-subscriber = { version="0.3", features = ["env-filter"] }
# "trace" provides TraceLayer/DefaultMakeSpan; "fs" is not used by main.rs as shown.
tower-http = { version = "0.2.0", features = ["fs", "trace"] }
# Typed header definitions (e.g. headers::UserAgent) for use with TypedHeader.
headers = "0.3"

关键是启用 axum 的 `ws` feature(`headers` feature 则用于 `TypedHeader`),接下来是示例代码 main.rs:

use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        TypedHeader,
    },
    response::IntoResponse,
    routing::{get},
    Router,
};
use std::net::SocketAddr;
use tower_http::{
    trace::{DefaultMakeSpan, TraceLayer},
};

/// Starts the demo server on 127.0.0.1:3000.
///
/// Routes:
/// - `GET /`   — plain-text "Hello, World!"
/// - `GET /ws` — WebSocket upgrade handled by `ws_handler`
///
/// Every request is traced (including headers) through tower-http's `TraceLayer`.
#[tokio::main]
async fn main() {
    // Default log filter used only when RUST_LOG is unset. The first target must
    // be this crate's name (`websocket`, from Cargo.toml); otherwise the crate's
    // own `tracing::debug!` output (e.g. "listening on ...") is filtered out.
    if std::env::var_os("RUST_LOG").is_none() {
        std::env::set_var("RUST_LOG", "websocket=debug,tower_http=debug")
    }
    tracing_subscriber::fmt::init();

    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        // Register the WebSocket upgrade route.
        .route("/ws", get(ws_handler))
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(DefaultMakeSpan::default().include_headers(true)),
        );

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option>,
) -> impl IntoResponse {
    if let Some(TypedHeader(user_agent)) = user_agent {
        println!("`{}` connected", user_agent.as_str());
    }

    ws.on_upgrade(handle_socket)
}

/// Echo server for one WebSocket connection (demo only).
///
/// Every Text/Binary message the client sends is sent back unchanged. The loop
/// ends when:
/// - the stream ends (`recv` returns `None`);
/// - the client sends a Close frame;
/// - a receive or send error occurs.
async fn handle_socket(mut socket: WebSocket) {
    while let Some(received) = socket.recv().await {
        let msg = match received {
            Ok(msg) => msg,
            Err(_) => {
                println!("client disconnected");
                return;
            }
        };
        println!("Client says: {:?}", msg);

        match msg {
            // Data frames: send the exact same content back (not its Debug form).
            Message::Text(_) | Message::Binary(_) => {
                if socket.send(msg).await.is_err() {
                    println!("client disconnected");
                    return;
                }
            }
            // The client started the closing handshake; stop serving this socket.
            Message::Close(_) => {
                println!("client disconnected");
                return;
            }
            // Ping/Pong control frames are not echoed.
            // NOTE(review): pings are presumably answered by the underlying
            // WebSocket library — confirm for the axum version in use.
            _ => {}
        }
    }
}

核心就是 handle_socket 这个函数,这里我们只是简单地将收到的内容原封不动地发回浏览器。

运行一下:

先用浏览器访问 http://localhost:3000/,然后按 F12 打开 Console 控制台,输入下面几行 JS 代码:

// Open a WebSocket to the demo server (global variable so it stays usable in the console).
socket = new WebSocket('ws://localhost:3000/ws');

// Print every message the server sends back.
socket.addEventListener('message', function (event) {
    console.log('Message from server ', event.data);
});

// send() throws InvalidStateError while the socket is still CONNECTING, so only
// send once the handshake is done. This works whether the lines are pasted at
// once or typed one by one (when the socket may already be open).
if (socket.readyState === WebSocket.OPEN) {
    socket.send('你好,RUST!');
} else {
    socket.addEventListener('open', function () {
        socket.send('你好,RUST!');
    }, { once: true });
}

就能在控制台看到服务端返回的内容。