2

我已经能够继续实现我的异步 UDP 服务器。但是,由于变量 `data` 的类型 `*mut u8` 不是 `Send`,下面这个错误出现了两次:

error: future cannot be sent between threads safely
 help: within `impl std::future::Future`, the trait `std::marker::Send` is not implemented for `*mut u8`
note: captured value is not `Send`

代码如下(MRE,最小可复现示例):

use std::error::Error;
use std::time::Duration;
use std::env;
use tokio::net::UdpSocket;
use tokio::{sync::mpsc, task, time}; // 1.4.0
use std::alloc::{alloc, Layout};
use std::mem;
use std::mem::MaybeUninit;
use std::net::SocketAddr;

/// Length of a UDP header, in bytes.
const UDP_HEADER: usize = 8;
/// Length of an IP header, in bytes (matches the minimum IPv4 header).
const IP_HEADER: usize = 20;
/// Length of the application header that prefixes every datagram: a big-endian `u16`
/// chunk index followed by a big-endian `u16` chunk count.
const AG_HEADER: usize = 4;

/// Largest payload accepted per datagram: the 16-bit IP length limit minus both headers
/// (65 507 bytes).
const MAX_DATA_LENGTH: usize = u16::MAX as usize - IP_HEADER - UDP_HEADER;
/// Largest file-data payload per datagram once the application header is removed.
const MAX_CHUNK_SIZE: usize = MAX_DATA_LENGTH - AG_HEADER;
/// Bytes reserved per chunk slot when sizing the file buffer (2^16).
const MAX_DATAGRAM_SIZE: usize = 1 << 16;

/// A wrapper for [`std::ptr::copy_nonoverlapping`] that takes its arguments in C `memcpy`
/// order: destination first, then source, then length in bytes.
///
/// # Safety
///
/// The caller must uphold the contract of [`std::ptr::copy_nonoverlapping`]:
/// * `src_ptr` must be valid for reads of `len` bytes;
/// * `dst_ptr` must be valid for writes of `len` bytes;
/// * both pointers must be non-null, even when `len == 0`;
/// * the source and destination regions must not overlap.
unsafe fn memcpy(dst_ptr: *mut u8, src_ptr: *const u8, len: usize) {
    // SAFETY: forwarded unchanged from this function's own safety contract.
    unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
}

/// Returns the exponent `k` of the smallest power of two `2^k` that is greater than or
/// equal to `n`.
///
/// Unlike [`u32::next_power_of_two`], which returns the power itself, this returns its
/// exponent. Both `n == 0` and `n == 1` yield `0` (because `2^0 == 1`), and every `n`
/// above `2^31` yields `32`.
const fn next_power_of_two_exponent(n: u32) -> u32 {
    // `saturating_sub` avoids the `0 - 1` underflow (a panic in debug builds) when the
    // caller passes a chunk count of zero; for every `n >= 1` it equals `n - 1`.
    32 - n.saturating_sub(1).leading_zeros()
}

/// Minimal reproduction from the question: receives UDP datagrams on `socket` and copies
/// each datagram's payload into a raw heap buffer (`data`) at an offset derived from the
/// chunk index in the datagram's first two bytes.
///
/// NOTE(review): as written this does NOT compile. Besides the error it was posted for
/// (`data` is a `*mut u8`, which is not `Send`, so the `async move` block handed to
/// `task::spawn` is not `Send`), there are further problems:
/// * both `task::spawn` calls sit inside `loop`, so on the second iteration the values
///   moved into the tasks (`debounce_rx`, `network_rx`, `missing_indexes`) and the
///   already-dropped `debounce_tx` would be used after being moved;
/// * `buf`, `start`, `data`, `layout` and `peer_addr` are `Copy`, so each `async move`
///   block captures its own snapshot. The server task never sees the bytes that
///   `recv_from` later writes into `buf`, and its `start = true` / `data = ...`
///   assignments are invisible to the debouncer and to this loop.
async fn run_server(socket: UdpSocket) {
    // Per-chunk flags: 0 = not yet received, 1 = received.
    let mut missing_indexes: Vec<u16> = Vec::new();
    // Sender of the first packet; written once, never read in this snippet.
    let mut peer_addr = MaybeUninit::<SocketAddr>::uninit();
    let mut data = std::ptr::null_mut(); // ptr for the file bytes
    // NOTE(review): `len` is never read.
    let mut len: usize = 0; // total len of bytes that will be written
    // Layout of the `data` allocation; initialised on the first packet.
    let mut layout = MaybeUninit::<Layout>::uninit();
    let mut buf = [0u8; MAX_DATA_LENGTH];
    // Set once the first packet of a file has initialised `data` and `layout`.
    let mut start = false;
    // Server task -> debouncer: (datagram size, sender) for every packet seen.
    let (debounce_tx, mut debounce_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);
    // Receive loop -> server task: (datagram size, sender); the bytes are expected in `buf`.
    let (network_tx, mut network_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);

    loop {
        // Debouncer task: logs activity and reports every 3.3 s without a packet.
        let debouncer = task::spawn(async move {
            let duration = Duration::from_millis(3300);

            loop {
                match time::timeout(duration, debounce_rx.recv()).await {
                    Ok(Some((size, peer))) => {
                        eprintln!("Network activity");
                    }
                    Ok(None) => {
                        // NOTE(review): this task's copy of `start` stays `false`, so this never
                        // breaks; presumably `recv()` then keeps returning `None` and the loop spins.
                        if start == true {
                            eprintln!("Debounce finished");
                            break;
                        }
                    }
                    Err(_) => {
                        eprintln!("{:?} since network activity", duration);
                    }
                }
            }
        });
        // Server task: parses each forwarded packet and copies its payload into `data`.
        let server = task::spawn({
            // async{
            let debounce_tx = debounce_tx.clone();
            async move {
                while let Some((size, peer)) = network_rx.recv().await {
                    // Received a new packet
                    debounce_tx.send((size, peer)).await.expect("Unable to talk to debounce");
                    eprintln!("Received a packet {} from: {}", size, peer);

                    // Big-endian chunk index from bytes 0..2.
                    // NOTE(review): `buf` here is this task's copy taken at spawn time,
                    // not the datagram that was just received.
                    let packet_index: u16 = (buf[0] as u16) << 8 | buf[1] as u16;

                    if start == false { // first bytes of a new file: initialization // TODO: ADD A MUTEX to prevent many initializations
                        start = true;
                        // Big-endian chunk count from bytes 2..4.
                        let chunks_cnt: u32 = (buf[2] as u32) << 8 | buf[3] as u32;
                        // Buffer size: one MAX_DATAGRAM_SIZE slot per chunk, rounded up to a power of two.
                        let n: usize = MAX_DATAGRAM_SIZE << next_power_of_two_exponent(chunks_cnt);
                        unsafe {
                            layout.as_mut_ptr().write(Layout::from_size_align_unchecked(n, mem::align_of::<u8>()));


                            // /!\  data has type `*mut u8` which is not `Send`
                            // NOTE(review): the result is never checked for null and is never freed.
                            data = alloc(layout.assume_init());

                            peer_addr.as_mut_ptr().write(peer);
                        }
                        let a: Vec<u16> = vec![0; chunks_cnt as usize]; //(0..chunks_cnt).map(|x| x as u16).collect(); // create a sorted vector with all the required indexes
                        missing_indexes = a;
                    }
                    // NOTE(review): panics if `packet_index >= chunks_cnt`.
                    missing_indexes[packet_index as usize] = 1;
                    unsafe {
                        let dst_ptr = data.offset((packet_index as usize * MAX_CHUNK_SIZE) as isize);
                        // NOTE(review): `size - AG_HEADER` underflows (panics in debug builds)
                        // for datagrams shorter than `AG_HEADER` bytes.
                        memcpy(dst_ptr, &buf[AG_HEADER], size - AG_HEADER);
                    };
                    println!("receiving packet {} from: {}", packet_index, peer);
                }
            }
        });

        // Prevent deadlocks
        // NOTE(review): after this `drop`, the next iteration's `debounce_tx.clone()` uses a moved value.
        drop(debounce_tx);

        // Receive one datagram into `buf` and tell the server task its size and sender.
        match socket.recv_from(&mut buf).await {
            Ok((size, src)) => {
                network_tx.send((size, src)).await.expect("Unable to talk to network");
            }
            Err(e) => {
                eprintln!("couldn't recieve a datagram: {}", e);
            }
        }
    }
}

/// Binds a UDP socket to the address given as the first command-line argument
/// (default `127.0.0.1:8080`) and runs the server on it.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or its local address cannot be read.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());
    let socket = UdpSocket::bind(&addr).await?;
    println!("Listening on: {}", socket.local_addr()?);
    // Futures are lazy: without `.await` the server future would be dropped unpolled and
    // the program would exit right after printing the address.
    run_server(socket).await;
    Ok(())
}

由于我正在把同步代码改写为异步代码,我知道多个线程可能会同时写入 `data`,这大概就是出现此错误的原因。但我不知道该用什么语法来“克隆”这个 `*mut u8` 指针,使每个线程都拥有自己独立的一份(缓冲区 `buf` 也有同样的问题)。

正如 user4815162342 所建议的,我认为最好的做法是:

把指针包装在一个结构体中,并为它声明 `unsafe impl Send for NewStruct {}`,从而让指针可以被 `Send`(在线程间移动)。

如有任何帮助,不胜感激!

PS:完整代码可以在我的 GitHub 仓库中找到。

4

1 个回答

1

精简版

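感谢 user4815162342 的评论,我决定把 `*mut u8` 指针包装在一个结构体中,并为该结构体实现 `Send` 和 `Sync`,从而解决了这部分问题(代码中还有其他问题,但超出了本问题的范围)。注意:`unsafe impl Send/Sync` 只是向编译器保证该类型可以跨线程使用,它本身不提供任何同步;调用方仍需自行确保同一时间只有一个任务读写该指针所指向的内存: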

/// Wrapper around the raw pointer to the file bytes, so the pointer can live inside a
/// future passed to `task::spawn`.
///
/// A bare `*mut u8` is neither `Send` nor `Sync`, so any future holding one across an
/// `.await` is rejected by `task::spawn`. The `unsafe impl`s below lift that restriction.
/// They also move the responsibility for thread safety from the compiler to the code that
/// dereferences `data`.
#[derive(Debug)]
pub struct FileBuffer {
    data: *mut u8,
}

// SAFETY: `FileBuffer` only carries a pointer value. Moving it to another thread is sound
// as long as the memory it points to is read and written by one task at a time (for
// example, only by the task that owns the `FileBuffer`). Users must uphold this; the
// compiler no longer checks it.
unsafe impl Send for FileBuffer {}
// SAFETY: `&FileBuffer` only exposes the pointer value. Dereferencing it requires `unsafe`
// code, which must ensure that no two threads access the same bytes concurrently.
unsafe impl Sync for FileBuffer {}

// Inside `run_server`, this replaces the old `data` declaration (kept commented out below).
// The future passed to `task::spawn` then holds a `FileBuffer`, which is `Send`, instead
// of a bare `*mut u8`. Writes go through `fileBuffer.data` from then on.
// NOTE(review): `fileBuffer` is not snake_case; rustc's `non_snake_case` lint will warn.
//let mut data = std::ptr::null_mut(); // ptr for the file bytes
let mut fileBuffer: FileBuffer = FileBuffer { data:  std::ptr::null_mut() };

完整版

use std::alloc::{alloc, dealloc, Layout};
use std::env;
use std::error::Error;
use std::mem;
use std::mem::MaybeUninit;
use std::net::SocketAddr;
use std::time::Duration;

use tokio::net::UdpSocket;
use tokio::{sync::mpsc, task, time}; // 1.4.0

/// Length of a UDP header, in bytes.
const UDP_HEADER: usize = 8;
/// Length of an IP header, in bytes (matches the minimum IPv4 header).
const IP_HEADER: usize = 20;
/// Length of the application header that prefixes every datagram: a big-endian `u16`
/// chunk index followed by a big-endian `u16` chunk count.
const AG_HEADER: usize = 4;

/// Largest payload accepted per datagram: the 16-bit IP length limit minus both headers
/// (65 507 bytes).
const MAX_DATA_LENGTH: usize = u16::MAX as usize - IP_HEADER - UDP_HEADER;
/// Largest file-data payload per datagram once the application header is removed.
const MAX_CHUNK_SIZE: usize = MAX_DATA_LENGTH - AG_HEADER;
/// Bytes reserved per chunk slot when sizing the file buffer (2^16).
const MAX_DATAGRAM_SIZE: usize = 1 << 16;

/// A wrapper for [`std::ptr::copy_nonoverlapping`] that takes its arguments in C `memcpy`
/// order: destination first, then source, then length in bytes.
///
/// # Safety
///
/// The caller must uphold the contract of [`std::ptr::copy_nonoverlapping`]:
/// * `src_ptr` must be valid for reads of `len` bytes;
/// * `dst_ptr` must be valid for writes of `len` bytes;
/// * both pointers must be non-null, even when `len == 0`;
/// * the source and destination regions must not overlap.
unsafe fn memcpy(dst_ptr: *mut u8, src_ptr: *const u8, len: usize) {
    // SAFETY: forwarded unchanged from this function's own safety contract.
    unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len) };
}

/// Returns the exponent `k` of the smallest power of two `2^k` that is greater than or
/// equal to `n`.
///
/// Unlike [`u32::next_power_of_two`], which returns the power itself, this returns its
/// exponent. Both `n == 0` and `n == 1` yield `0` (because `2^0 == 1`), and every `n`
/// above `2^31` yields `32`.
const fn next_power_of_two_exponent(n: u32) -> u32 {
    // `saturating_sub` avoids the `0 - 1` underflow (a panic in debug builds) when the
    // caller passes a chunk count of zero; for every `n >= 1` it equals `n - 1`.
    32 - n.saturating_sub(1).leading_zeros()
}

 /// Wrapper around the raw pointer to the file bytes, so the pointer can live inside a
/// future passed to `task::spawn`.
///
/// A bare `*mut u8` is neither `Send` nor `Sync`, so any future holding one across an
/// `.await` is rejected by `task::spawn`. The `unsafe impl`s below lift that restriction.
/// They also move the responsibility for thread safety from the compiler to the code that
/// dereferences `data`.
#[derive(Debug)]
pub struct FileBuffer {
    data: *mut u8,
}

// SAFETY: `FileBuffer` only carries a pointer value. Moving it to another thread is sound
// as long as the memory it points to is read and written by one task at a time (for
// example, only by the task that owns the `FileBuffer`). Users must uphold this; the
// compiler no longer checks it.
unsafe impl Send for FileBuffer {}
// SAFETY: `&FileBuffer` only exposes the pointer value. Dereferencing it requires `unsafe`
// code, which must ensure that no two threads access the same bytes concurrently.
unsafe impl Sync for FileBuffer {}

async fn run_server(socket: UdpSocket) {
    let mut missing_indexes: Vec<u16> = Vec::new();
    let mut peer_addr = MaybeUninit::<SocketAddr>::uninit();
    //let mut data = std::ptr::null_mut(); // ptr for the file bytes
    let mut fileBuffer: FileBuffer = FileBuffer { data:  std::ptr::null_mut() };
    let mut len: usize = 0; // total len of bytes that will be written
    let mut layout = MaybeUninit::<Layout>::uninit();
    let mut buf = [0u8; MAX_DATA_LENGTH];
    let mut start = false;
    let (debounce_tx, mut debounce_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);
    let (network_tx, mut network_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);

    loop {
        // Listen for events
        let debouncer = task::spawn(async move {
            let duration = Duration::from_millis(3300);

            loop {
                match time::timeout(duration, debounce_rx.recv()).await {
                    Ok(Some((size, peer))) => {
                        eprintln!("Network activity");
                    }
                    Ok(None) => {
                        if start == true {
                            eprintln!("Debounce finished");
                            break;
                        }
                    }
                    Err(_) => {
                        eprintln!("{:?} since network activity", duration);
                    }
                }
            }
        });
        // Listen for network activity
        let server = task::spawn({
            // async{
            let debounce_tx = debounce_tx.clone();
            async move {
                while let Some((size, peer)) = network_rx.recv().await {
                    // Received a new packet
                    debounce_tx.send((size, peer)).await.expect("Unable to talk to debounce");
                    eprintln!("Received a packet {} from: {}", size, peer);

                    let packet_index: u16 = (buf[0] as u16) << 8 | buf[1] as u16;

                    if start == false { // first bytes of a new file: initialization // TODO: ADD A MUTEX to prevent many initializations
                        start = true;
                        let chunks_cnt: u32 = (buf[2] as u32) << 8 | buf[3] as u32;
                        let n: usize = MAX_DATAGRAM_SIZE << next_power_of_two_exponent(chunks_cnt);
                        unsafe {
                            layout.as_mut_ptr().write(Layout::from_size_align_unchecked(n, mem::align_of::<u8>()));

                            // /!\  data has type `*mut u8` which is not `Send`
                            fileBuffer.data = alloc(layout.assume_init());

                            peer_addr.as_mut_ptr().write(peer);
                        }
                        let a: Vec<u16> = vec![0; chunks_cnt as usize]; //(0..chunks_cnt).map(|x| x as u16).collect(); // create a sorted vector with all the required indexes
                        missing_indexes = a;
                    }
                    missing_indexes[packet_index as usize] = 1;
                    unsafe {
                        let dst_ptr = fileBuffer.data.offset((packet_index as usize * MAX_CHUNK_SIZE) as isize);
                        memcpy(dst_ptr, &buf[AG_HEADER], size - AG_HEADER);
                    };
                    println!("receiving packet {} from: {}", packet_index, peer);
                }
            }
        });

        // Prevent deadlocks
        drop(debounce_tx);

        match socket.recv_from(&mut buf).await {
            Ok((size, src)) => {
                network_tx.send((size, src)).await.expect("Unable to talk to network");
            }
            Err(e) => {
                eprintln!("couldn't recieve a datagram: {}", e);
            }
        }
    }
}

/// Binds a UDP socket to the address given as the first command-line argument
/// (default `127.0.0.1:8080`) and runs the server on it.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or its local address cannot be read.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());
    let socket = UdpSocket::bind(&addr).await?;
    println!("Listening on: {}", socket.local_addr()?);
    // Futures are lazy: without `.await` the server future would be dropped unpolled and
    // the program would exit right after printing the address.
    run_server(socket).await;
    Ok(())
}
于 2021-03-30T15:33:58.723 回答