更新:从 jetty-9.4.15.v20190215 起,Jetty 内置了对端口统一的支持;参见这个答案。
我们可以
这是可能的,我们已经做到了。这里的代码适用于 Jetty 8;我没有用 Jetty 9 测试过,但是这个答案有类似的 Jetty 9 代码。
顺便说一句,这被称为端口统一(port unification);据说使用 Grizzly 的 Glassfish 很早就支持它了。
大纲
基本思想是编写一个 org.eclipse.jetty.server.Connector 的实现,它可以提前查看客户端请求的第一个字节。幸运的是,HTTP 和 HTTPS 都是由客户端先开始通信的。对于 HTTPS(通常是 TLS/SSL),第一个字节将是 0x16(TLS)或 >= 0x80(SSLv2)。对于 HTTP,第一个字节将是普通的可打印 7 位 ASCII 字符。现在,Connector 将根据第一个字节产生 SSL 连接或普通连接。
在这里的代码中,我们利用了这样一个事实:Jetty 的 SslSelectChannelConnector 本身继承自 SelectChannelConnector,并且有一个 newPlainConnection() 方法(调用其超类来产生一个非 SSL 连接)以及一个 newConnection() 方法(产生一个 SSL 连接)。因此,我们的新 Connector 可以扩展 SslSelectChannelConnector,并在观察到来自客户端的第一个字节后委托给这两个方法之一。
不幸的是,我们需要在第一个字节可用之前就创建一个 AsyncConnection 实例。该实例的某些方法甚至可能在第一个字节可用之前被调用。因此,我们创建了一个 LazyConnection implements AsyncConnection,它可以稍后再确定要委托给哪种连接,甚至在确定之前就能为某些方法返回合理的默认响应。
由于基于 NIO,我们的 Connector 将使用 SocketChannel。幸运的是,我们可以扩展 SocketChannel 来创建一个 ReadAheadSocketChannelWrapper,它委托给“真实”的 SocketChannel,但可以检查并存储客户端消息的第一个字节。
一些细节
有一处非常 hacky 的地方。我们的 Connector 必须重写的方法之一是 customize(Endpoint,Request)。如果我们最终得到的是基于 SSL 的 Endpoint,就可以直接传递给超类;否则超类将抛出 ClassCastException,但那是在它已经传递给自己的超类、并且在 Request 上设置了 scheme 之后才抛出的。所以我们照样传递给超类,但在捕获到该异常时撤销对 scheme 的设置。
我们还重写了 isConfidential() 和 isIntegral(),以确保我们的 servlet 可以正确使用 HttpServletRequest.isSecure() 来判断使用的是 HTTP 还是 HTTPS。
尝试从客户端读取第一个字节可能会抛出 IOException,但我们可能不得不在一个不允许抛出 IOException 的地方进行这种尝试;在这种情况下,我们会保留该异常并稍后再抛出它。
在 Java >= 7 和 Java 6 中,扩展 SocketChannel 的写法有所不同。对于后者,只需注释掉 Java 6 的 SocketChannel 所没有的方法即可。
代码
public class PortUnificationSelectChannelConnector extends SslSelectChannelConnector {
public PortUnificationSelectChannelConnector() {
super();
}
public PortUnificationSelectChannelConnector(SslContextFactory sslContextFactory) {
super(sslContextFactory);
}
@Override
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, SelectionKey key) throws IOException {
return super.newEndPoint(new ReadAheadSocketChannelWrapper(channel, 1), selectSet, key);
}
@Override
protected AsyncConnection newConnection(SocketChannel channel, AsyncEndPoint endPoint) {
return new LazyConnection((ReadAheadSocketChannelWrapper)channel, endPoint);
}
@Override
public void customize(EndPoint endpoint, Request request) throws IOException {
String scheme = request.getScheme();
try {
super.customize(endpoint, request);
} catch (ClassCastException e) {
request.setScheme(scheme);
}
}
@Override
public boolean isConfidential(Request request) {
if (request.getAttribute("javax.servlet.request.cipher_suite") != null) return true;
else return isForwarded() && request.getScheme().equalsIgnoreCase(HttpSchemes.HTTPS);
}
@Override
public boolean isIntegral(Request request) {
return isConfidential(request);
}
class LazyConnection implements AsyncConnection {
private final ReadAheadSocketChannelWrapper channel;
private final AsyncEndPoint endPoint;
private final long timestamp;
private AsyncConnection connection;
public LazyConnection(ReadAheadSocketChannelWrapper channel, AsyncEndPoint endPoint) {
this.channel = channel;
this.endPoint = endPoint;
this.timestamp = System.currentTimeMillis();
this.connection = determineNewConnection(channel, endPoint, false);
}
public Connection handle() throws IOException {
if (connection == null) {
connection = determineNewConnection(channel, endPoint, false);
channel.throwPendingException();
}
if (connection != null) return connection.handle();
else return this;
}
public long getTimeStamp() {
return timestamp;
}
public void onInputShutdown() throws IOException {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onInputShutdown();
}
public boolean isIdle() {
if (connection == null) connection = determineNewConnection(channel, endPoint, false);
if (connection != null) return connection.isIdle();
else return false;
}
public boolean isSuspended() {
if (connection == null) connection = determineNewConnection(channel, endPoint, false);
if (connection != null) return connection.isSuspended();
else return false;
}
public void onClose() {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onClose();
}
public void onIdleExpired(long l) {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onIdleExpired(l);
}
AsyncConnection determineNewConnection(ReadAheadSocketChannelWrapper channel, AsyncEndPoint endPoint, boolean force) {
byte[] bytes = channel.getBytes();
if ((bytes == null || bytes.length == 0) && !force) return null;
if (looksLikeSsl(bytes)) {
return PortUnificationSelectChannelConnector.super.newConnection(channel, endPoint);
} else {
return PortUnificationSelectChannelConnector.super.newPlainConnection(channel, endPoint);
}
}
// TLS first byte is 0x16
// SSLv2 first byte is >= 0x80
// HTTP is guaranteed many bytes of ASCII
private boolean looksLikeSsl(byte[] bytes) {
if (bytes == null || bytes.length == 0) return false; // force HTTP
byte b = bytes[0];
return b >= 0x7F || (b < 0x20 && b != '\n' && b != '\r' && b != '\t');
}
}
static class ReadAheadSocketChannelWrapper extends SocketChannel {
private final SocketChannel channel;
private final ByteBuffer start;
private byte[] bytes;
private IOException pendingException;
private int leftToRead;
public ReadAheadSocketChannelWrapper(SocketChannel channel, int readAheadLength) throws IOException {
super(channel.provider());
this.channel = channel;
start = ByteBuffer.allocate(readAheadLength);
leftToRead = readAheadLength;
readAhead();
}
public synchronized void readAhead() throws IOException {
if (leftToRead > 0) {
int n = channel.read(start);
if (n == -1) {
leftToRead = -1;
} else {
leftToRead -= n;
}
if (leftToRead <= 0) {
start.flip();
bytes = new byte[start.remaining()];
start.get(bytes);
start.rewind();
}
}
}
public byte[] getBytes() {
if (pendingException == null) {
try {
readAhead();
} catch (IOException e) {
pendingException = e;
}
}
return bytes;
}
public void throwPendingException() throws IOException {
if (pendingException != null) {
IOException e = pendingException;
pendingException = null;
throw e;
}
}
private int readFromStart(ByteBuffer dst) throws IOException {
int sr = start.remaining();
int dr = dst.remaining();
if (dr == 0) return 0;
int n = Math.min(dr, sr);
dst.put(bytes, start.position(), n);
start.position(start.position() + n);
return n;
}
public synchronized int read(ByteBuffer dst) throws IOException {
throwPendingException();
readAhead();
if (leftToRead > 0) return 0;
int sr = start.remaining();
if (sr > 0) {
int n = readFromStart(dst);
if (n < sr) return n;
}
return sr + channel.read(dst);
}
public synchronized long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
throwPendingException();
if (offset + length > dsts.length || length < 0 || offset < 0) {
throw new IndexOutOfBoundsException();
}
readAhead();
if (leftToRead > 0) return 0;
int sr = start.remaining();
int newOffset = offset;
if (sr > 0) {
int accum = 0;
for (; newOffset < offset + length; newOffset++) {
accum += readFromStart(dsts[newOffset]);
if (accum == sr) break;
}
if (accum < sr) return accum;
}
return sr + channel.read(dsts, newOffset, length - newOffset + offset);
}
public int hashCode() {
return channel.hashCode();
}
public boolean equals(Object obj) {
return channel.equals(obj);
}
public String toString() {
return channel.toString();
}
public Socket socket() {
return channel.socket();
}
public boolean isConnected() {
return channel.isConnected();
}
public boolean isConnectionPending() {
return channel.isConnectionPending();
}
public boolean connect(SocketAddress remote) throws IOException {
return channel.connect(remote);
}
public boolean finishConnect() throws IOException {
return channel.finishConnect();
}
public int write(ByteBuffer src) throws IOException {
return channel.write(src);
}
public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
return channel.write(srcs, offset, length);
}
@Override
protected void implCloseSelectableChannel() throws IOException {
channel.close();
}
@Override
protected void implConfigureBlocking(boolean block) throws IOException {
channel.configureBlocking(block);
}
// public SocketAddress getLocalAddress() throws IOException {
// return channel.getLocalAddress();
// }
//
// public <T> T getOption(java.net.SocketOption<T> name) throws IOException {
// return channel.getOption(name);
// }
//
// public Set<java.net.SocketOption<?>> supportedOptions() {
// return channel.supportedOptions();
// }
//
// public SocketChannel bind(SocketAddress local) throws IOException {
// return channel.bind(local);
// }
//
// public SocketAddress getRemoteAddress() throws IOException {
// return channel.getRemoteAddress();
// }
//
// public <T> SocketChannel setOption(java.net.SocketOption<T> name, T value) throws IOException {
// return channel.setOption(name, value);
// }
//
// public SocketChannel shutdownInput() throws IOException {
// return channel.shutdownInput();
// }
//
// public SocketChannel shutdownOutput() throws IOException {
// return channel.shutdownOutput();
// }
}
}