我正在使用 ConcurrentHashMap 来缓存我在 SocketChannel 上处理的任务。StreamTask 是一个 Runnable,用于在客户端服务器通信期间超过往返阈值时重新调度自身,因此如果时间过去,它将自己从缓存中删除。此外,StreamWriteTask 线程会将其放在缓存中,而 StreaReadTask 将尝试将其删除。
问题是当我调用“processingCache.put()”时,它并不总是添加到地图中。
public class ClientServerTest {
private class StreamTask implements Runnable {
private final String taskIdentifier;
private byte[] data;
private int scheduleAttempts = 1;
private long startTime;
private Runnable future;
private static final long ROND_TRIP_THRESHOLD = 15000L;
private static final int MAX_SCHEDULE_ATTEMPTS = 3;
public StreamTask(String taskIdentifier, byte[] data) {
super();
this.taskIdentifier = taskIdentifier;
this.data = data;
}
@Override
public void run() {
if (scheduleAttempts < MAX_SCHEDULE_ATTEMPTS) {
StreamTask task = null;
processingCacheLock.writeLock().lock();
try{
task = processingCache.remove(taskIdentifier);
}finally{
processingCacheLock.writeLock().unlock();
}
if (task == null) {
return;
}
scheduleStreamTask(task);
scheduleAttempts++;
} else {
failedTasks.add(this);
}
}
@Override
public int hashCode() {
return taskIdentifier == null ? 0 : super.hashCode();
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (!(obj instanceof StreamTask)) {
return false;
}
StreamTask task = (StreamTask) obj;
boolean equals = false;
if (this.taskIdentifier != null
&& this.taskIdentifier.equals(task.taskIdentifier)) {
equals = true;
}
if (this.hashCode() == task.hashCode()) {
equals = true;
}
return equals;
}
}
private class StreamWriteTask implements Runnable {
private ByteBuffer buffer;
private SelectionKey key;
private StreamWriteTask(ByteBuffer buffer, SelectionKey key) {
this.buffer = buffer;
this.key = key;
}
private byte[] getData() {
byte[] data;
if (key.attachment() != null) {
data = (byte[]) key.attachment();
System.out.println("StreamWriteTask continuation.....");
} else {
StreamTask task = getStreamTask();
if (task == null) {
return null;
}
System.out.println("Processing New Task ~~~~~ "
+ task.taskIdentifier);
processingCacheLock.readLock().lock();
try {
task = processingCache.put(task.taskIdentifier, task);
boolean cached = processingCache.containsKey(task.taskIdentifier);
System.out.println("Has task been cached? " + cached);
} finally {
processingCacheLock.readLock().unlock();
}
task.startTime = System.currentTimeMillis();
data = task.data;
}
return data;
}
@Override
public void run() {
byte[] data = getData();
if (data != null) {
SocketChannel sc = (SocketChannel) key.channel();
buffer.clear();
buffer.put(data);
buffer.flip();
int results = 0;
while (buffer.hasRemaining()) {
try {
results = sc.write(buffer);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
if (results == 0) {
buffer.compact();
buffer.flip();
data = new byte[buffer.remaining()];
buffer.get(data);
key.interestOps(SelectionKey.OP_WRITE);
key.attach(data);
System.out
.println("Partial write to socket channel....");
selector.wakeup();
return;
}
}
}
System.out
.println("Write to socket channel complete for client...");
key.interestOps(SelectionKey.OP_READ);
key.attach(null);
returnBuffer(buffer);
selector.wakeup();
}
}
private class StreamReadTask implements Runnable {
private ByteBuffer buffer;
private SelectionKey key;
private StreamReadTask(ByteBuffer buffer, SelectionKey key) {
this.buffer = buffer;
this.key = key;
}
@Override
public void run() {
long endTime = System.currentTimeMillis();
SocketChannel sc = (SocketChannel) key.channel();
buffer.clear();
byte[] data = (byte[]) key.attachment();
if (data != null) {
buffer.put(data);
}
int count = 0;
int readAttempts = 0;
try {
while ((count = sc.read(buffer)) > 0) {
readAttempts++;
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
if (count == 0) {
buffer.flip();
data = new byte[buffer.limit()];
buffer.get(data);
String uuid = new String(data);
System.out.println("Client Read - uuid ~~~~ " + uuid);
boolean success = finalizeStreamTask(uuid, endTime);
key.interestOps(SelectionKey.OP_WRITE);
key.attach(null);
System.out.println("Did task finalize correctly ~~~~ "
+ success);
}
if (count == -1) {
try {
sc.close();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
returnBuffer(buffer);
selector.wakeup();
}
}
private class ClientWorker implements Runnable {
@Override
public void run() {
try {
while (selector.isOpen()) {
int count = selector.select(500);
if (count == 0) {
continue;
}
Iterator<SelectionKey> it = selector.selectedKeys()
.iterator();
while (it.hasNext()) {
final SelectionKey key = it.next();
it.remove();
if (!key.isValid()) {
continue;
}
if (key.isConnectable()) {
SocketChannel sc = (SocketChannel) key.channel();
if (!sc.finishConnect()) {
continue;
}
sc.register(selector, SelectionKey.OP_WRITE);
}
if (key.isReadable()) {
ByteBuffer buffer = borrowBuffer();
if (buffer != null) {
key.interestOps(0);
executor.execute(new StreamReadTask(buffer, key));
}
}
if (key.isWritable()) {
ByteBuffer buffer = borrowBuffer();
if (buffer != null) {
key.interestOps(0);
executor.execute(new StreamWriteTask(buffer,
key));
}
}
}
}
} catch (IOException ex) {
// Handle Exception
}
}
}
private class ServerWorker implements Runnable {
@Override
public void run() {
try {
Selector selector = Selector.open();
ServerSocketChannel ssc = ServerSocketChannel.open();
ServerSocket socket = ssc.socket();
socket.bind(new InetSocketAddress(9001));
ssc.configureBlocking(false);
ssc.register(selector, SelectionKey.OP_ACCEPT);
ByteBuffer buffer = ByteBuffer.allocateDirect(65535);
DataHandler handler = new DataHandler();
while (selector.isOpen()) {
int count = selector.select(500);
if (count == 0) {
continue;
}
Iterator<SelectionKey> it = selector.selectedKeys()
.iterator();
while (it.hasNext()) {
final SelectionKey key = it.next();
it.remove();
if (!key.isValid()) {
continue;
}
if (key.isAcceptable()) {
ssc = (ServerSocketChannel) key.channel();
SocketChannel sc = ssc.accept();
sc.configureBlocking(false);
sc.register(selector, SelectionKey.OP_READ);
}
if (key.isReadable()) {
handler.readSocket(buffer, key);
}
if (key.isWritable()) {
handler.writeToSocket(buffer, key);
}
}
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
private class DataHandler {
private JsonObject parseData(StringBuilder builder) {
if (!builder.toString().endsWith("}")) {
return null;
}
JsonParser parser = new JsonParser();
JsonObject obj = (JsonObject) parser.parse(builder.toString());
return obj;
}
private void readSocket(ByteBuffer buffer, SelectionKey key)
throws IOException {
SocketChannel sc = (SocketChannel) key.channel();
buffer.clear();
int count = Integer.MAX_VALUE;
int readAttempts = 0;
try {
while ((count = sc.read(buffer)) > 0) {
readAttempts++;
}
} catch (IOException e) {
e.printStackTrace();
}
if (count == 0) {
buffer.flip();
StringBuilder builder = key.attachment() instanceof StringBuilder ? (StringBuilder) key
.attachment() : new StringBuilder();
Charset charset = Charset.forName("UTF-8");
CharsetDecoder decoder = charset.newDecoder();
decoder.onMalformedInput(CodingErrorAction.IGNORE);
CharBuffer charBuffer = decoder.decode(buffer);
String content = charBuffer.toString();
charBuffer = null;
builder.append(content);
JsonObject obj = parseData(builder);
if (obj == null) {
// System.out.println("Server processed partial read for task");
key.attach(builder);
key.interestOps(SelectionKey.OP_READ);
} else {
JsonPrimitive uuid = obj.get("uuid").getAsJsonPrimitive();
System.out
.println("Server read complete for task ~~~~~~~ "
+ uuid);
key.attach(uuid.toString().getBytes());
key.interestOps(SelectionKey.OP_WRITE);
}
}
if (count == -1) {
key.attach(null);
sc.close();
}
}
private void writeToSocket(ByteBuffer buffer, SelectionKey key)
throws IOException {
SocketChannel sc = (SocketChannel) key.channel();
byte[] data = (byte[]) key.attachment();
buffer.clear();
buffer.put(data);
buffer.flip();
int writeAttempts = 0;
while (buffer.hasRemaining()) {
int results = sc.write(buffer);
writeAttempts++;
// System.out.println("Write Attempt #" + writeAttempts);
if (results == 0) {
System.out.println("Server process partial write....");
buffer.compact();
buffer.flip();
data = new byte[buffer.remaining()];
buffer.get(data);
key.attach(data);
key.interestOps(SelectionKey.OP_WRITE);
return;
}
}
System.out.println("Server write complete for task ~~~~~ "
+ new String(data));
key.interestOps(SelectionKey.OP_READ);
key.attach(null);
}
}
public ClientServerTest() throws IOException {
selector = Selector.open();
processingCache = new ConcurrentHashMap<String, StreamTask>(
MAX_DATA_LOAD, 2);
for (int index = 0; index < MAX_DATA_LOAD; index++) {
JsonObject obj = new JsonObject();
String uuid = UUID.randomUUID().toString();
obj.addProperty("uuid", uuid);
String data = RandomStringUtils.randomAlphanumeric(12800000);
obj.addProperty("event", data);
StreamTask task = new StreamTask(uuid, obj.toString().getBytes());
taskQueue.add(task);
}
for (int index = 0; index < CLIENT_SOCKET_CONNECTIONS; index++) {
ByteBuffer bf = ByteBuffer.allocate(2 << 23);
bufferQueue.add(bf);
SocketChannel sc = SocketChannel.open();
sc.configureBlocking(false);
sc.connect(new InetSocketAddress("127.0.0.1", 9001));
sc.register(selector, SelectionKey.OP_CONNECT);
}
Thread serverWorker = new Thread(new ServerWorker());
serverWorker.start();
Thread clientWorker = new Thread(new ClientWorker());
clientWorker.start();
}
private void start() {
long startTime = System.currentTimeMillis();
for (;;) {
if (taskQueue.isEmpty() && processingCache.isEmpty()) {
long endTime = System.currentTimeMillis();
System.out.println("Overall Processing time ~~~~ "
+ (endTime - startTime) + "ms");
break;
}
}
}
private ByteBuffer borrowBuffer() {
ByteBuffer buffer = null;
try {
buffer = bufferQueue.poll(5000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
return buffer;
}
private boolean returnBuffer(ByteBuffer buffer) {
boolean success = true;
try {
buffer.clear();
bufferQueue.offer(buffer, 5000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
success = false;
e.printStackTrace();
}
return success;
}
private StreamTask getStreamTask() {
StreamTask task = null;
taskQueueAddLock.lock();
try {
task = taskQueue.take();
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} finally {
taskQueueAddLock.unlock();
}
return task;
}
private boolean scheduleStreamTask(StreamTask task) {
boolean success = true;
taskQueueRemoveLock.lock();
try {
taskQueue.offer(task, 5000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
success = false;
e.printStackTrace();
} finally {
taskQueueRemoveLock.unlock();
}
return success;
}
private boolean finalizeStreamTask(String uuid, long endTime) {
boolean success = true;
StreamTask task;
processingCacheLock.writeLock().lock();
try {
task = processingCache.remove(uuid);
success = task != null;
} finally {
processingCacheLock.writeLock().unlock();
}
if (success) {
success = executor.remove(task.future);
executor.purge();
}
if (!success) {
taskQueueAddLock.lock();
taskQueueRemoveLock.lock();
try {
Iterator<StreamTask> it = taskQueue.iterator();
while (it.hasNext()) {
task = it.next();
if (task.taskIdentifier == uuid) {
it.remove();
success = true;
}
}
} finally {
taskQueueAddLock.unlock();
taskQueueRemoveLock.unlock();
}
success = !taskQueue.contains(task);
}
System.out.println("Processing time ~~~~~~ "
+ (endTime - task.startTime) + "ms");
return success;
}
/**
* @param args
*/
public static void main(String[] args) {
try {
ClientServerTest test = new ClientServerTest();
test.start();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
private static final int CLIENT_SOCKET_CONNECTIONS = 1;
private static final int MAX_DATA_LOAD = 2;
private volatile ConcurrentHashMap<String, StreamTask> processingCache;
private volatile LinkedBlockingQueue<StreamTask> taskQueue = new LinkedBlockingQueue<StreamTask>();
private volatile ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(
CLIENT_SOCKET_CONNECTIONS);
private volatile LinkedBlockingQueue<ByteBuffer> bufferQueue = new LinkedBlockingQueue<ByteBuffer>();
private volatile List<StreamTask> failedTasks = new ArrayList<StreamTask>();
private volatile Selector selector;
private final ReentrantLock taskQueueAddLock = new ReentrantLock();
private final ReentrantLock taskQueueRemoveLock = new ReentrantLock();
private final ReentrantReadWriteLock processingCacheLock = new ReentrantReadWriteLock();
}