0

我正在尝试用 C# 编写一个小型、简洁的异步 TCP 套接字服务器。此示例只是将每个请求数据包回显给其连接的客户端。目前,我遇到了池阵列损坏的问题。由于数据包头接收缓冲区损坏,我的示例引发异常(第 27 行)。缓冲区中的值将是负数或巨大的(对于数据包长度而言,大约为数百兆字节)。这是由于缓冲区中位置 0-4 的字节被设置为随机值。

编辑:我忘了提到损坏只在运行一两分钟后出现(仅抛出异常)。此外,对于单个客户端似乎还可以,但对于更多的客户端,腐败的可能性更高。

EDIT2:无论出于何种原因,该错误都是零星的,并且不容易重现。我试图在多台计算机上运行代码,并且在两台计算机上都出现了问题(但也没有发生)。我已经完成了整个运行而没有错误,并且我在 1 分钟或 5 分钟内(或其他各种时间)出现了运行错误。我实现了一个信号量来尝试更轻松地控制活动客户端的数量,但它无助于显示错误。

internal class Program
{
    // consistently working
    //public const int ConcurrentClients = 8, ClientCount = 1_000, PacketCount = 1_000, MinPacketBytes = 8, MaxPacketBytes = 64;

    // not consistently working
    public const int ConcurrentClients = 8, ClientCount = 1000, PacketCount = 1_000, MinPacketBytes = 1024, MaxPacketBytes = 8192;

    private static readonly Encoding ServerEncoding = Encoding.UTF8;
    private static readonly IPEndPoint ServerEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 12347);
    private static readonly IPEndPoint ClientEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 0);

    private static readonly SemaphoreSlim ClientSemaphore = new SemaphoreSlim(ConcurrentClients);

    private static void ClientTask()
    {
        TcpClient client = new TcpClient();

        client.Bind(ClientEndPoint);
        client.Connect(ServerEndPoint);

        byte[] request = new byte[MaxPacketBytes];
        byte[] response = new byte[MaxPacketBytes];

        byte[] message = ServerEncoding.GetBytes("Hello, World!");

        for (int i = 0; i < PacketCount; i++)
        {
            message.CopyTo(request, 0);

            int written = client.Send(i, request);
            int read = client.Receive(response, out int id);
        }

        client.Disconnect();
        client.Dispose();

        ClientSemaphore.Release();
    }

    private static void ServerTask()
    {
        TcpServer server = new TcpServer();

        server.Bind(ServerEndPoint);
        server.Run(ClientCount);

        Console.WriteLine("Press enter to stop the server...");
        _ = Console.ReadLine();

        server.Stop();
        server.Dispose();

        Console.WriteLine("Press enter to exit...");
        _ = Console.ReadLine();
    }

    private static void Main(string[] args)
    {
        Task serverTask = Task.Factory.StartNew(ServerTask);

        for (int i = 0; i < ClientCount; i++)
        {
            ClientSemaphore.Wait();

            Task.Factory.StartNew(ClientTask);
        }

        Console.WriteLine("All clients completed successfully");

        serverTask.GetAwaiter().GetResult();
    }
}

我不确定我是否遇到了 Use-After-Free 错误、Double-Free 或其他问题。我已经三重检查了我的缓冲区租用和释放,但似乎没问题。到目前为止,我已将损坏范围缩小到仅在执行第 376 行之后发生(在将响应发送回客户端并且正在设置套接字参数以接收下一个请求数据包标头之后)。

异常堆栈跟踪是:

Unhandled exception. System.Exception: Exception of type 'System.Exception' was thrown.
   at NetworkingPlayground.Packet.ReadHeader(ReadOnlySpan`1 buffer, Int32& length, Int32& id) in Program.cs:line 27
   at NetworkingPlayground.TcpServer.HandleReceive(SocketAsyncEventArgs args) in Program.cs:line 215
   at NetworkingPlayground.TcpServer.HandleClient(SocketAsyncEventArgs args) in Program.cs:line 376
   at NetworkingPlayground.TcpServer.HandleAccept(SocketAsyncEventArgs args) in Program.cs:line 189

缓冲区寿命

单个请求-响应周期的缓冲区的预期生命周期如下:

Step:    Receive Request Header --> Receive Request Body --> Send Response Header + Body
Buf A:  [ Header Receive Buffer ]
Buf B:                            [ Body Receive Buffer ]
Buf C:                                              [ Response Send Buffer             ]
  • 应在租用主体接收缓冲区之前返回标头接收缓冲区(因为相关的标头信息已被复制出来:参见 HandleReceive,第 212-225 行)。

  • 主体接收缓冲区生存期应与响应发送缓冲区生存期重叠(因为必须将请求内容复制到响应主体缓冲区:请参阅 HandleClient,第 367-377 行)。

代码

GitHub 要点:https ://gist.github.com/mblenczewski/b0a302d443ef95e1d2eaefa6f3a18582

using System;
using System.Buffers;
using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace NetworkingPlayground
{
    internal static class Packet
    {
        public const int HeaderSize = sizeof(int) + sizeof(int);

        public static int TotalSize(int dataSize)
        {
            return HeaderSize + dataSize;
        }

        public static void ReadHeader(ReadOnlySpan<byte> buffer, out int length, out int id)
        {
            length = MemoryMarshal.Read<int>(buffer.Slice(0, sizeof(int)));
            
            if (length < 0 || Program.MaxPacketBytes < length)
                throw new Exception();

            id = MemoryMarshal.Read<int>(buffer.Slice(sizeof(int), sizeof(int)));
        }

        public static void WriteHeader(Span<byte> buffer, int length, int id)
        {
            MemoryMarshal.Write(buffer.Slice(0, sizeof(int)), ref length);
            MemoryMarshal.Write(buffer.Slice(sizeof(int), sizeof(int)), ref id);
        }

        public static void Serialise(Span<byte> buffer, int id, ReadOnlySpan<byte> data, int length)
        {
            WriteHeader(buffer, length, id);

            data.Slice(0, length).CopyTo(buffer.Slice(HeaderSize));
        }

        public static void Deserialise(ReadOnlySpan<byte> buffer, Span<byte> dataBuffer, out int id)
        {
            ReadHeader(buffer, out int length, out id);

            buffer.Slice(HeaderSize, length).CopyTo(dataBuffer);
        }
    }

    internal class Program
    {
        public const int PacketCount = 1000, MinPacketBytes = 1024, MaxPacketBytes = 8192;
        public const int ClientCount = 1000;

        private static readonly Encoding ServerEncoding = Encoding.UTF8;
        private static readonly IPEndPoint ServerEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 12347);
        private static readonly IPEndPoint ClientEndPoint = new IPEndPoint(IPAddress.IPv6Loopback, 0);

        private static void ClientTask()
        {
            TcpClient client = new TcpClient();

            client.Bind(ClientEndPoint);
            client.Connect(ServerEndPoint);

            byte[] request = new byte[MaxPacketBytes];
            byte[] response = new byte[MaxPacketBytes];

            byte[] message = ServerEncoding.GetBytes("Hello, World!");

            for (int i = 0; i < PacketCount; i++)
            {
                message.CopyTo(request, 0);

                int written = client.Send(i, request);
                int read = client.Receive(response, out int id);
            }

            client.Disconnect();
            client.Dispose();
        }

        private static void ServerTask()
        {
            TcpServer server = new TcpServer();

            server.Bind(ServerEndPoint);
            server.Run(ClientCount);

            Console.WriteLine("Press enter to stop the server...");
            _ = Console.ReadLine();

            server.Stop();

            server.Dispose();
        }

        private static void Main(string[] args)
        {
            Task serverTask = Task.Factory.StartNew(ServerTask);

            for (int i = 0; i < ClientCount; i++)
            {
                Task.Factory.StartNew(ClientTask);

                Thread.Sleep(500);
            }

            serverTask.GetAwaiter().GetResult();
        }
    }

    internal class TcpServer : IDisposable
    {
        private readonly ArrayPool<byte> bufferPool;
        private readonly Socket socket;
        private bool disposed, active;

        private readonly ManualResetEventSlim serverShutdownEvent;

        private int clientCount;
        private readonly object clientCountLock = new object();

        public TcpServer()
        {
            bufferPool = ArrayPool<byte>.Shared;

            socket = new Socket(SocketType.Stream, ProtocolType.Tcp);

            serverShutdownEvent = new ManualResetEventSlim(false);
        }

        public void Bind(EndPoint localEndPoint)
        {
            socket.Bind(localEndPoint);
        }

        public void Run(int listenCount)
        {
            active = true;
            socket.Listen(listenCount);

            AcceptPendingClients();
        }

        public void Stop()
        {
            if (!active) return;

            serverShutdownEvent.Set();

            lock (clientCountLock)
            {
                while (clientCount > 0)
                {
                    _ = Monitor.Wait(clientCountLock);
                }
            }

            active = false;
        }

        private void AcceptPendingClients()
        {
            while (true)
            {
                SocketAsyncEventArgs args = new SocketAsyncEventArgs();
                args.Completed += HandleIoCompleted;

                if (socket.AcceptAsync(args))
                {
                    break;
                }

                HandleAccept(args);
            }
        }

        private void HandleAccept(SocketAsyncEventArgs args)
        {
            switch (args.SocketError)
            {
                case SocketError.Success:
                    AcceptPendingClients();

                    HandleClient(args);
                    break;

                default:
                    break;
            }
        }

        private void HandleReceive(SocketAsyncEventArgs args)
        {
            ClientState clientState = (ClientState)args.UserToken;
            int received = args.BytesTransferred;
            int previouslyReceived = args.Offset;
            int totalReceived = previouslyReceived + received;
            int expected = clientState.TransferSize;

            switch (args.SocketError)
            {
                case SocketError.Success:
                    Socket client = args.AcceptSocket;

                    if (expected == Packet.HeaderSize) // receiving header
                    {
                        if (totalReceived == expected)
                        {
                            // parse received header
                            Packet.ReadHeader(args.Buffer, out int dataLength, out int dataId);
                            clientState.RequestLength = dataLength;
                            clientState.RequestId = dataId;

                            // setup socket args to receive data
                            bufferPool.Return(args.Buffer, clearArray: true);
                            args.SetBuffer(bufferPool.Rent(dataLength), 0, dataLength);
                            clientState.TransferSize = dataLength;

                            if (!client.ReceiveAsync(args)) HandleReceive(args);
                        }
                        else if (totalReceived > 0 && totalReceived < expected)
                        {
                            args.SetBuffer(totalReceived, expected - totalReceived);
                            if (!client.ReceiveAsync(args)) HandleReceive(args);
                        }
                        else if (received == 0)
                        {
                            clientState.ClientShutdownEvent.Set();
                        }
                    }
                    else // receiving body
                    {
                        if (totalReceived == expected)
                        {
                            clientState.ReceivedRequestEvent.Set();
                        }
                        else if (totalReceived > 0 && totalReceived < expected)
                        {
                            args.SetBuffer(totalReceived, expected - totalReceived);
                            if (!client.ReceiveAsync(args)) HandleReceive(args);
                        }
                        else if (received == 0)
                        {
                            clientState.ClientShutdownEvent.Set();
                        }
                    }
                    break;

                default:
                    clientState.ClientShutdownEvent.Set();
                    break;
            }
        }

        private void HandleSend(SocketAsyncEventArgs args)
        {
            ClientState clientState = (ClientState)args.UserToken;
            int sent = args.BytesTransferred;
            int previouslySent = args.Offset;
            int totalSent = previouslySent + sent;
            int expected = clientState.TransferSize;

            switch (args.SocketError)
            {
                case SocketError.Success:
                    Socket client = args.AcceptSocket;

                    if (totalSent == expected)
                    {
                        clientState.SentResponseEvent.Set();
                    }
                    else if (totalSent > 0 && totalSent < expected)
                    {
                        args.SetBuffer(totalSent, expected - totalSent);
                        if (!client.SendAsync(args)) HandleSend(args);
                    }
                    else if (sent == 0)
                    {
                        clientState.ClientShutdownEvent.Set();
                    }
                    break;

                default:
                    clientState.ClientShutdownEvent.Set();
                    break;
            }
        }

        private void HandleIoCompleted(object sender, SocketAsyncEventArgs args)
        {
            switch (args.LastOperation)
            {
                case SocketAsyncOperation.Accept:
                    HandleAccept(args);
                    break;

                case SocketAsyncOperation.Receive:
                    HandleReceive(args);
                    break;

                case SocketAsyncOperation.Send:
                    HandleSend(args);
                    break;
            }
        }

        private void HandleClient(SocketAsyncEventArgs args)
        {
            lock (clientCountLock)
            {
                clientCount++;
            }

            Socket client = args.AcceptSocket;
            Console.WriteLine($"[Server] Client {client.RemoteEndPoint} opened connection");

            ClientState clientState = new ClientState();
            args.UserToken = clientState;

            WaitHandle[] handles = new WaitHandle[]
            {
                serverShutdownEvent.WaitHandle,
                clientState.ClientShutdownEvent.WaitHandle,
                clientState.ReceivedRequestEvent.WaitHandle,
                clientState.SentResponseEvent.WaitHandle
            };

            // setup socket args to receive initial client request header
            args.SetBuffer(bufferPool.Rent(Packet.HeaderSize), 0, Packet.HeaderSize);
            clientState.TransferSize = Packet.HeaderSize;

            // start an async receive
            if (!client.ReceiveAsync(args)) HandleReceive(args);

            while (true)
            {
                int completedHandleIdx = WaitHandle.WaitAny(handles);
                switch (completedHandleIdx)
                {
                    case 0: // server shutdown event
                        goto serverShutdown;

                    case 1: // client shutdown event
                        goto clientShutdown;

                    case 2: // client request received event
                        clientState.ReceivedRequestEvent.Reset();

                        // setup socket args to send reply
                        int responseBufferLength = Packet.TotalSize(clientState.RequestLength);
                        byte[] responseBuffer = bufferPool.Rent(responseBufferLength);
                        Packet.Serialise(responseBuffer, clientState.RequestId, args.Buffer, clientState.RequestLength);
                        bufferPool.Return(args.Buffer, clearArray: true);

                        args.SetBuffer(responseBuffer, 0, responseBufferLength);
                        clientState.TransferSize = responseBufferLength;

                        // start an async send
                        if (!client.SendAsync(args)) HandleSend(args);
                        break;

                    case 3: // client response sent event
                        clientState.SentResponseEvent.Reset();

                        // setup socket args to receive header
                        bufferPool.Return(args.Buffer, clearArray: true);
                        args.SetBuffer(bufferPool.Rent(Packet.HeaderSize), 0, Packet.HeaderSize);
                        clientState.TransferSize = Packet.HeaderSize;

                        // start an async receive
                        if (!client.ReceiveAsync(args)) HandleReceive(args);
                        break;

                    default:
                        Debug.Assert(false, "Unhandled wait handle index!");
                        break;
                }
            }

            clientShutdown:
            Console.WriteLine($"[Server] Client {client.RemoteEndPoint} closed connection");

            serverShutdown:
            // return any pooled objects
            bufferPool.Return(args.Buffer, clearArray: true);

            client.Disconnect(false);
            client.Shutdown(SocketShutdown.Both);
            client.Close();
            client.Dispose();

            lock (clientCountLock)
            {
                clientCount--;
                Monitor.Pulse(clientCountLock);
            }
        }

        protected virtual void Dispose(bool disposing)
        {
            if (!disposed)
            {
                if (disposing)
                {
                    if (active) Stop();

                    socket.Close();
                    socket.Dispose();
                }

                disposed = true;
            }
        }

        public void Dispose()
        {
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }

        internal class ClientState
        {
            internal int TransferSize { get; set; }
            internal int RequestLength { get; set; }
            internal int RequestId { get; set; }
            internal ManualResetEventSlim ReceivedRequestEvent { get; set; } = new ManualResetEventSlim(false);
            internal ManualResetEventSlim SentResponseEvent { get; set; } = new ManualResetEventSlim(false);
            internal ManualResetEventSlim ClientShutdownEvent { get; set; } = new ManualResetEventSlim(false);
        }
    }

    internal class TcpClient : IDisposable
    {
        private readonly ArrayPool<byte> bufferPool;
        private readonly Socket socket;
        private bool disposed;

        public TcpClient()
        {
            bufferPool = ArrayPool<byte>.Shared;

            socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
        }

        public void Bind(EndPoint localEndPoint)
        {
            socket.Bind(localEndPoint);
        }

        public void Connect(EndPoint remoteEndPoint)
        {
            socket.Connect(remoteEndPoint);
        }

        public void Disconnect(bool reuseSocket = false)
        {
            socket.Disconnect(reuseSocket);
        }

        public int Send(int id, ReadOnlyMemory<byte> buffer)
        {
            byte[] scratch = bufferPool.Rent(Packet.HeaderSize + buffer.Length);

            Packet.Serialise(scratch, id, buffer.Span, buffer.Length);

            int written = socket.Send(scratch);

            bufferPool.Return(scratch, clearArray: true);

            return written;
        }

        public int Receive(Memory<byte> buffer, out int id)
        {
            byte[] scratch = bufferPool.Rent(Packet.HeaderSize + buffer.Length);

            int read = socket.Receive(scratch);

            Packet.Deserialise(scratch, buffer.Span, out id);

            bufferPool.Return(scratch, clearArray: true);

            return read;
        }

        protected virtual void Dispose(bool disposing)
        {
            if (!disposed)
            {
                if (disposing)
                {
                    socket.Shutdown(SocketShutdown.Both);
                    socket.Close();
                    socket.Dispose();
                }

                disposed = true;
            }
        }

        public void Dispose()
        {
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }
    }
}
4

0 回答 0