0

我正在使用 ConcurrentHashMap 来缓存我在 SocketChannel 上处理的任务。StreamTask 是一个 Runnable,用于在客户端服务器通信期间超过往返阈值时重新调度自身,因此如果时间过去,它将自己从缓存中删除。此外,StreamWriteTask 线程会将其放在缓存中,而 StreaReadTask 将尝试将其删除。

问题是当我调用“processingCache.put()”时,它并不总是添加到地图中。

public class ClientServerTest {

    private class StreamTask implements Runnable {
        private final String taskIdentifier;
        private byte[] data;
        private int scheduleAttempts = 1;
        private long startTime;
        private Runnable future;

        private static final long ROND_TRIP_THRESHOLD = 15000L;
        private static final int MAX_SCHEDULE_ATTEMPTS = 3;

        public StreamTask(String taskIdentifier, byte[] data) {
            super();
            this.taskIdentifier = taskIdentifier;
            this.data = data;
        }

        @Override
        public void run() {
            if (scheduleAttempts < MAX_SCHEDULE_ATTEMPTS) {
                StreamTask task = null;
                processingCacheLock.writeLock().lock(); 
                try{
                    task = processingCache.remove(taskIdentifier);
                }finally{
                    processingCacheLock.writeLock().unlock();
                }

                if (task == null) {
                    return;
                }

                scheduleStreamTask(task);
                scheduleAttempts++;
            } else {
                failedTasks.add(this);
            }

        }

        @Override
        public int hashCode() {
            return taskIdentifier == null ? 0 : super.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == null) {
                return false;
            }

            if (!(obj instanceof StreamTask)) {
                return false;
            }
            StreamTask task = (StreamTask) obj;
            boolean equals = false;
            if (this.taskIdentifier != null
                    && this.taskIdentifier.equals(task.taskIdentifier)) {
                equals = true;
            }

            if (this.hashCode() == task.hashCode()) {
                equals = true;
            }

            return equals;
        }

    }

    private class StreamWriteTask implements Runnable {
        private ByteBuffer buffer;
        private SelectionKey key;

        private StreamWriteTask(ByteBuffer buffer, SelectionKey key) {
            this.buffer = buffer;
            this.key = key;
        }

        private byte[] getData() {
            byte[] data;
            if (key.attachment() != null) {
                data = (byte[]) key.attachment();
                System.out.println("StreamWriteTask continuation.....");
            } else {
                StreamTask task = getStreamTask();
                if (task == null) {
                    return null;
                }
                System.out.println("Processing New Task ~~~~~ "
                        + task.taskIdentifier);
                processingCacheLock.readLock().lock();
                try {
                    task = processingCache.put(task.taskIdentifier, task);
                    boolean cached = processingCache.containsKey(task.taskIdentifier);
                    System.out.println("Has task been cached? " + cached);
                } finally {
                    processingCacheLock.readLock().unlock();
                }

                task.startTime = System.currentTimeMillis();
                data = task.data;
            }

            return data;
        }

        @Override
        public void run() {
            byte[] data = getData();
            if (data != null) {
                SocketChannel sc = (SocketChannel) key.channel();
                buffer.clear();
                buffer.put(data);
                buffer.flip();
                int results = 0;
                while (buffer.hasRemaining()) {
                    try {
                        results = sc.write(buffer);
                    } catch (IOException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    }

                    if (results == 0) {
                        buffer.compact();
                        buffer.flip();
                        data = new byte[buffer.remaining()];
                        buffer.get(data);
                        key.interestOps(SelectionKey.OP_WRITE);
                        key.attach(data);
                        System.out
                                .println("Partial write to socket channel....");
                        selector.wakeup();
                        return;
                    }
                }
            }

            System.out
                    .println("Write to socket channel complete for client...");
            key.interestOps(SelectionKey.OP_READ);
            key.attach(null);
            returnBuffer(buffer);
            selector.wakeup();
        }

    }

    private class StreamReadTask implements Runnable {
        private ByteBuffer buffer;
        private SelectionKey key;

        private StreamReadTask(ByteBuffer buffer, SelectionKey key) {
            this.buffer = buffer;
            this.key = key;
        }

        @Override
        public void run() {
            long endTime = System.currentTimeMillis();
            SocketChannel sc = (SocketChannel) key.channel();
            buffer.clear();
            byte[] data = (byte[]) key.attachment();
            if (data != null) {
                buffer.put(data);
            }
            int count = 0;
            int readAttempts = 0;
            try {
                while ((count = sc.read(buffer)) > 0) {
                    readAttempts++;
                }
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }

            if (count == 0) {
                buffer.flip();
                data = new byte[buffer.limit()];
                buffer.get(data);
                String uuid = new String(data);
                System.out.println("Client Read - uuid ~~~~ " + uuid);
                boolean success = finalizeStreamTask(uuid, endTime);
                key.interestOps(SelectionKey.OP_WRITE);
                key.attach(null);
                System.out.println("Did task finalize correctly ~~~~ "
                        + success);
            }

            if (count == -1) {
                try {
                    sc.close();
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }

            returnBuffer(buffer);
            selector.wakeup();
        }

    }

    private class ClientWorker implements Runnable {

        @Override
        public void run() {
            try {

                while (selector.isOpen()) {
                    int count = selector.select(500);

                    if (count == 0) {
                        continue;
                    }

                    Iterator<SelectionKey> it = selector.selectedKeys()
                            .iterator();

                    while (it.hasNext()) {
                        final SelectionKey key = it.next();
                        it.remove();
                        if (!key.isValid()) {
                            continue;
                        }

                        if (key.isConnectable()) {
                            SocketChannel sc = (SocketChannel) key.channel();
                            if (!sc.finishConnect()) {
                                continue;
                            }
                            sc.register(selector, SelectionKey.OP_WRITE);
                        }

                        if (key.isReadable()) {
                            ByteBuffer buffer = borrowBuffer();
                            if (buffer != null) {
                                key.interestOps(0);
                                executor.execute(new StreamReadTask(buffer, key));
                            }
                        }
                        if (key.isWritable()) {
                            ByteBuffer buffer = borrowBuffer();
                            if (buffer != null) {
                                key.interestOps(0);
                                executor.execute(new StreamWriteTask(buffer,
                                        key));
                            }
                        }
                    }
                }
            } catch (IOException ex) {
                // Handle Exception
            }

        }
    }

    private class ServerWorker implements Runnable {
        @Override
        public void run() {
            try {
                Selector selector = Selector.open();
                ServerSocketChannel ssc = ServerSocketChannel.open();
                ServerSocket socket = ssc.socket();
                socket.bind(new InetSocketAddress(9001));
                ssc.configureBlocking(false);
                ssc.register(selector, SelectionKey.OP_ACCEPT);
                ByteBuffer buffer = ByteBuffer.allocateDirect(65535);
                DataHandler handler = new DataHandler();

                while (selector.isOpen()) {
                    int count = selector.select(500);

                    if (count == 0) {
                        continue;
                    }

                    Iterator<SelectionKey> it = selector.selectedKeys()
                            .iterator();

                    while (it.hasNext()) {
                        final SelectionKey key = it.next();
                        it.remove();
                        if (!key.isValid()) {
                            continue;
                        }

                        if (key.isAcceptable()) {
                            ssc = (ServerSocketChannel) key.channel();
                            SocketChannel sc = ssc.accept();
                            sc.configureBlocking(false);
                            sc.register(selector, SelectionKey.OP_READ);
                        }
                        if (key.isReadable()) {
                            handler.readSocket(buffer, key);
                        }
                        if (key.isWritable()) {
                            handler.writeToSocket(buffer, key);
                        }
                    }
                }

            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }

    }

    private class DataHandler {

        private JsonObject parseData(StringBuilder builder) {
            if (!builder.toString().endsWith("}")) {
                return null;
            }

            JsonParser parser = new JsonParser();
            JsonObject obj = (JsonObject) parser.parse(builder.toString());
            return obj;
        }

        private void readSocket(ByteBuffer buffer, SelectionKey key)
                throws IOException {
            SocketChannel sc = (SocketChannel) key.channel();
            buffer.clear();
            int count = Integer.MAX_VALUE;
            int readAttempts = 0;
            try {
                while ((count = sc.read(buffer)) > 0) {
                    readAttempts++;
                }
            } catch (IOException e) {
                e.printStackTrace();
            }

            if (count == 0) {
                buffer.flip();
                StringBuilder builder = key.attachment() instanceof StringBuilder ? (StringBuilder) key
                        .attachment() : new StringBuilder();
                Charset charset = Charset.forName("UTF-8");
                CharsetDecoder decoder = charset.newDecoder();
                decoder.onMalformedInput(CodingErrorAction.IGNORE);
                CharBuffer charBuffer = decoder.decode(buffer);
                String content = charBuffer.toString();
                charBuffer = null;
                builder.append(content);
                JsonObject obj = parseData(builder);
                if (obj == null) {
                    // System.out.println("Server processed partial read for task");
                    key.attach(builder);
                    key.interestOps(SelectionKey.OP_READ);
                } else {
                    JsonPrimitive uuid = obj.get("uuid").getAsJsonPrimitive();
                    System.out
                            .println("Server read complete for task  ~~~~~~~ "
                                    + uuid);
                    key.attach(uuid.toString().getBytes());
                    key.interestOps(SelectionKey.OP_WRITE);
                }
            }

            if (count == -1) {
                key.attach(null);
                sc.close();
            }
        }

        private void writeToSocket(ByteBuffer buffer, SelectionKey key)
                throws IOException {
            SocketChannel sc = (SocketChannel) key.channel();
            byte[] data = (byte[]) key.attachment();
            buffer.clear();
            buffer.put(data);
            buffer.flip();
            int writeAttempts = 0;
            while (buffer.hasRemaining()) {
                int results = sc.write(buffer);
                writeAttempts++;
                // System.out.println("Write Attempt #" + writeAttempts);
                if (results == 0) {
                    System.out.println("Server process partial write....");
                    buffer.compact();
                    buffer.flip();
                    data = new byte[buffer.remaining()];
                    buffer.get(data);
                    key.attach(data);
                    key.interestOps(SelectionKey.OP_WRITE);
                    return;
                }
            }

            System.out.println("Server write complete for task ~~~~~ "
                    + new String(data));
            key.interestOps(SelectionKey.OP_READ);
            key.attach(null);
        }
    }

    public ClientServerTest() throws IOException {
        selector = Selector.open();
        processingCache = new ConcurrentHashMap<String, StreamTask>(
                MAX_DATA_LOAD, 2);
        for (int index = 0; index < MAX_DATA_LOAD; index++) {
            JsonObject obj = new JsonObject();
            String uuid = UUID.randomUUID().toString();
            obj.addProperty("uuid", uuid);
            String data = RandomStringUtils.randomAlphanumeric(12800000);
            obj.addProperty("event", data);
            StreamTask task = new StreamTask(uuid, obj.toString().getBytes());
            taskQueue.add(task);
        }

        for (int index = 0; index < CLIENT_SOCKET_CONNECTIONS; index++) {
            ByteBuffer bf = ByteBuffer.allocate(2 << 23);
            bufferQueue.add(bf);
            SocketChannel sc = SocketChannel.open();
            sc.configureBlocking(false);
            sc.connect(new InetSocketAddress("127.0.0.1", 9001));
            sc.register(selector, SelectionKey.OP_CONNECT);
        }

        Thread serverWorker = new Thread(new ServerWorker());
        serverWorker.start();

        Thread clientWorker = new Thread(new ClientWorker());
        clientWorker.start();

    }

    private void start() {
        long startTime = System.currentTimeMillis();
        for (;;) {
            if (taskQueue.isEmpty() && processingCache.isEmpty()) {
                long endTime = System.currentTimeMillis();
                System.out.println("Overall Processing time ~~~~ "
                        + (endTime - startTime) + "ms");
                break;
            }
        }
    }

    private ByteBuffer borrowBuffer() {
        ByteBuffer buffer = null;

        try {
            buffer = bufferQueue.poll(5000L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        return buffer;
    }

    private boolean returnBuffer(ByteBuffer buffer) {
        boolean success = true;
        try {
            buffer.clear();
            bufferQueue.offer(buffer, 5000L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            success = false;
            e.printStackTrace();
        }
        return success;
    }

    private StreamTask getStreamTask() {
        StreamTask task = null;
        taskQueueAddLock.lock();
        try {
            task = taskQueue.take();
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        } finally {
            taskQueueAddLock.unlock();
        }

        return task;
    }

    private boolean scheduleStreamTask(StreamTask task) {
        boolean success = true;
        taskQueueRemoveLock.lock();
        try {
            taskQueue.offer(task, 5000L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            success = false;
            e.printStackTrace();
        } finally {
            taskQueueRemoveLock.unlock();
        }

        return success;
    }

    private boolean finalizeStreamTask(String uuid, long endTime) {
        boolean success = true;
        StreamTask task;
        processingCacheLock.writeLock().lock();
        try {
            task = processingCache.remove(uuid);
            success = task != null;
        } finally {
            processingCacheLock.writeLock().unlock();
        }

        if (success) {
            success = executor.remove(task.future);
            executor.purge();
        }

        if (!success) {
            taskQueueAddLock.lock();
            taskQueueRemoveLock.lock();
            try {
                Iterator<StreamTask> it = taskQueue.iterator();
                while (it.hasNext()) {
                    task = it.next();
                    if (task.taskIdentifier == uuid) {
                        it.remove();
                        success = true;
                    }
                }
            } finally {
                taskQueueAddLock.unlock();
                taskQueueRemoveLock.unlock();
            }
            success = !taskQueue.contains(task);
        }

        System.out.println("Processing time ~~~~~~ "
                + (endTime - task.startTime) + "ms");
        return success;
    }

    /**
     * @param args
     */
    public static void main(String[] args) {
        try {
            ClientServerTest test = new ClientServerTest();
            test.start();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }

    private static final int CLIENT_SOCKET_CONNECTIONS = 1;
    private static final int MAX_DATA_LOAD = 2;

    private volatile ConcurrentHashMap<String, StreamTask> processingCache;
    private volatile LinkedBlockingQueue<StreamTask> taskQueue = new LinkedBlockingQueue<StreamTask>();
    private volatile ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(
            CLIENT_SOCKET_CONNECTIONS);
    private volatile LinkedBlockingQueue<ByteBuffer> bufferQueue = new LinkedBlockingQueue<ByteBuffer>();
    private volatile List<StreamTask> failedTasks = new ArrayList<StreamTask>();
    private volatile Selector selector;
    private final ReentrantLock taskQueueAddLock = new ReentrantLock();
    private final ReentrantLock taskQueueRemoveLock = new ReentrantLock();
    private final ReentrantReadWriteLock processingCacheLock = new ReentrantReadWriteLock();
}
4

1 回答 1

0

您的问题可能是对put()返回的内容的误解。在这一行之后:

task = processingCache.put(task.taskIdentifier, task);

task 等于存储在该键的映射中的先前值(如果有),否则为 null。task.taskIdentifier如果地图在该调用之前没有键,则put()返回 null 和下一行:

boolean cached = processingCache.containsKey(task.taskIdentifier);

会抛出一个NullPointerException.

来自ConcurrenMap#put javadoc(强调我的):

返回与 key 关联的先前值,如果没有 key 映射,则返回null

于 2012-06-06T08:32:23.897 回答