看来解决方案是将 Bruno 的建议与 Paulo 的解决方案结合起来。
Paulo 的解决方案允许我们使用委托自定义 SSLSocket 或 SSLServerSocket 的行为。
Bruno 的建议允许我们告诉默认 SSL 实现使用我们修改后的 SSLSocket 或 SSLServerSocket。
这是我所做的:
- 创建一个委托 ServerSocket 类(MyServerSocket)
- 创建一个委托 ServerSocketFactory 类(MyServerSocketFactory)
- 创建一个委托 SocketFactory 类(MySocketFactory)
- 创建一个委托 Socket 类(MySocket)
- 创建 XorInputStream(实现见原文链接)
- 创建 XorOutputStream(实现见原文链接)
在服务器端:
// Server-side setup: plain TCP server sockets come from MyServerSocketFactory
// (so accept() yields MySocket instances with XOR-wrapped streams), and TLS is
// then layered on top of each accepted connection.
// Initialisation as usual
...
sslSocketFactory = sslContext.getSocketFactory();
serverSocketFactory = ServerSocketFactory.getDefault();
// Wrap the default factory so that accepted sockets are MySocket instances.
serverSocketFactory = new MyServerSocketFactory(serverSocketFactory);
serverSocket = serverSocketFactory.createServerSocket(port);
...
// The (Socket) cast is redundant: accept() already returns Socket.
Socket s = (Socket) serverSocket.accept();
// Layer TLS over the accepted (XOR-wrapped) socket. autoClose=false means that
// closing sslSocket does NOT close s; s must be closed separately.
sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, null, s.getPort(), false);
// Act as the TLS server on this connection (must be set before the handshake starts).
sslSocket.setUseClientMode(false);
// NOTE(review): RC4/MD5 is cryptographically broken and disabled by default in
// current JDKs (jdk.tls.disabledAlgorithms); the handshake will likely fail.
// Choose a modern cipher suite instead.
sslSocket.setEnabledCipherSuites(new String[]{"SSL_RSA_WITH_RC4_128_MD5"});
// Require the client to present a certificate.
sslSocket.setNeedClientAuth(true);
...
在客户端:
// Client side: open a plain TCP connection through MySocketFactory (whose
// createSocket(String, int) returns a MySocket with XOR-wrapped streams),
// then layer TLS on top of it.
Socket s = new MySocketFactory(SocketFactory.getDefault()).createSocket(host, port);
// NOTE(review): `factory` is not defined in this snippet; presumably an
// SSLSocketFactory such as sslContext.getSocketFactory() — confirm.
// autoClose=false: closing `socket` does not close `s`; close `s` separately.
SSLSocket socket = (SSLSocket) factory.createSocket(s, host, port, false);
源代码:
/**
 * A {@link ServerSocket} that forwards its operations to a wrapped server socket and
 * wraps every accepted connection in a {@link MySocket}.
 *
 * <p>Only the methods overridden here are forwarded. Any other {@code ServerSocket}
 * method acts on this object's own (never bound) superclass state, not on the delegate.
 */
public class MyServerSocket extends ServerSocket {

    /** The real server socket; all overridden calls are forwarded to it. */
    private final ServerSocket baseSocket;

    /**
     * Creates a wrapper around {@code baseSocket}.
     *
     * @param baseSocket the server socket to delegate to; must not be {@code null}
     * @throws IOException as declared by the {@link ServerSocket#ServerSocket()} super constructor
     * @throws NullPointerException if {@code baseSocket} is {@code null}
     */
    public MyServerSocket(ServerSocket baseSocket) throws IOException {
        // Fail fast here rather than with an NPE on the first forwarded call.
        this.baseSocket = java.util.Objects.requireNonNull(baseSocket, "baseSocket");
    }

    /**
     * Accepts a connection on the delegate and wraps it in a {@link MySocket}.
     *
     * @return the accepted connection, wrapped
     * @throws IOException if the delegate's {@code accept()} fails
     */
    @Override
    public Socket accept() throws IOException {
        return new MySocket(baseSocket.accept());
    }

    @Override
    public void bind(SocketAddress endpoint) throws IOException {
        baseSocket.bind(endpoint);
    }

    @Override
    public void bind(SocketAddress endpoint, int backlog) throws IOException {
        baseSocket.bind(endpoint, backlog);
    }

    @Override
    public void close() throws IOException {
        baseSocket.close();
    }

    @Override
    public ServerSocketChannel getChannel() {
        return baseSocket.getChannel();
    }

    @Override
    public InetAddress getInetAddress() {
        return baseSocket.getInetAddress();
    }

    @Override
    public int getLocalPort() {
        return baseSocket.getLocalPort();
    }

    @Override
    public SocketAddress getLocalSocketAddress() {
        return baseSocket.getLocalSocketAddress();
    }

    @Override
    public synchronized int getReceiveBufferSize() throws SocketException {
        return baseSocket.getReceiveBufferSize();
    }

    @Override
    public boolean getReuseAddress() throws SocketException {
        return baseSocket.getReuseAddress();
    }

    @Override
    public synchronized int getSoTimeout() throws IOException {
        return baseSocket.getSoTimeout();
    }

    @Override
    public boolean isBound() {
        return baseSocket.isBound();
    }

    @Override
    public boolean isClosed() {
        return baseSocket.isClosed();
    }

    @Override
    public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
        baseSocket.setPerformancePreferences(connectionTime, latency, bandwidth);
    }

    @Override
    public synchronized void setReceiveBufferSize(int size) throws SocketException {
        baseSocket.setReceiveBufferSize(size);
    }

    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        baseSocket.setReuseAddress(on);
    }

    @Override
    public synchronized void setSoTimeout(int timeout) throws SocketException {
        baseSocket.setSoTimeout(timeout);
    }

    @Override
    public String toString() {
        return baseSocket.toString();
    }
}
/**
 * A {@link ServerSocketFactory} that delegates socket creation to a wrapped factory and
 * wraps every server socket it returns in a {@link MyServerSocket}.
 */
public class MyServerSocketFactory extends ServerSocketFactory {

    /** The factory that actually creates the server sockets. */
    private final ServerSocketFactory baseFactory;

    /**
     * Creates a wrapping factory.
     *
     * @param baseFactory the factory to delegate to; must not be {@code null}
     * @throws NullPointerException if {@code baseFactory} is {@code null}
     */
    public MyServerSocketFactory(ServerSocketFactory baseFactory) {
        this.baseFactory = java.util.Objects.requireNonNull(baseFactory, "baseFactory");
    }

    /**
     * Creates an unbound server socket via the delegate, wrapped in {@link MyServerSocket}.
     *
     * <p>Without this override, the inherited {@link ServerSocketFactory#createServerSocket()}
     * default throws {@link java.net.SocketException}, even when the delegate supports
     * unbound server sockets.
     *
     * @return the wrapped, unbound server socket
     * @throws IOException if the delegate cannot create the socket
     */
    @Override
    public ServerSocket createServerSocket() throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket());
    }

    /**
     * Creates a server socket bound to {@code port}, wrapped in {@link MyServerSocket}.
     *
     * @param port the local port to bind to
     * @return the wrapped server socket
     * @throws IOException if the delegate cannot create or bind the socket
     */
    @Override
    public ServerSocket createServerSocket(int port) throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket(port));
    }

    /**
     * Creates a server socket bound to {@code port} with the given backlog, wrapped in
     * {@link MyServerSocket}.
     *
     * @param port the local port to bind to
     * @param backlog the requested maximum length of the pending-connection queue
     * @return the wrapped server socket
     * @throws IOException if the delegate cannot create or bind the socket
     */
    @Override
    public ServerSocket createServerSocket(int port, int backlog) throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket(port, backlog));
    }

    /**
     * Creates a server socket bound to {@code port} on {@code ifAddress} with the given
     * backlog, wrapped in {@link MyServerSocket}.
     *
     * @param port the local port to bind to
     * @param backlog the requested maximum length of the pending-connection queue
     * @param ifAddress the local address to bind to
     * @return the wrapped server socket
     * @throws IOException if the delegate cannot create or bind the socket
     */
    @Override
    public ServerSocket createServerSocket(int port, int backlog, InetAddress ifAddress)
            throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket(port, backlog, ifAddress));
    }
}
/**
 * A {@link Socket} that forwards its operations to a wrapped socket, except that its
 * input and output streams are wrapped in {@code XorInputStream} / {@code XorOutputStream}
 * constructed with a fixed one-byte pattern.
 *
 * <p>NOTE(review): {@link #getChannel()} and {@link #sendUrgentData(int)} go straight to
 * the wrapped socket and therefore bypass the XOR stream wrappers. Socket methods not
 * overridden here act on this object's own unconnected superclass state, not on the
 * wrapped socket.
 */
public class MySocket extends Socket {

    /** Byte passed to the XOR stream wrappers (presumably both peers must agree on it). */
    private static final byte PATTERN = (byte) 0xAC;

    /** The real socket; all overridden calls are forwarded to it. */
    private final Socket baseSocket;

    /** Lazily created wrapper around {@code baseSocket}'s input stream. */
    private XorInputStream xorInputStream = null;

    /** Lazily created wrapper around {@code baseSocket}'s output stream. */
    private XorOutputStream xorOutputStream = null;

    /**
     * Creates a wrapper around {@code baseSocket}.
     *
     * @param baseSocket the socket to delegate to; must not be {@code null}
     * @throws NullPointerException if {@code baseSocket} is {@code null}
     */
    public MySocket(Socket baseSocket) {
        // Fail fast here rather than with an NPE on the first forwarded call.
        this.baseSocket = java.util.Objects.requireNonNull(baseSocket, "baseSocket");
    }

    /**
     * Returns the wrapped socket's input stream wrapped in an {@code XorInputStream}.
     * The wrapper is created on first call and the same instance is returned afterwards.
     *
     * @return the XOR-wrapping input stream
     * @throws IOException if the wrapped socket's input stream cannot be obtained
     */
    @Override
    public InputStream getInputStream() throws IOException {
        if (xorInputStream == null) {
            xorInputStream = new XorInputStream(baseSocket.getInputStream(), PATTERN);
        }
        return xorInputStream;
    }

    /**
     * Returns the wrapped socket's output stream wrapped in an {@code XorOutputStream}.
     * The wrapper is created on first call and the same instance is returned afterwards.
     *
     * @return the XOR-wrapping output stream
     * @throws IOException if the wrapped socket's output stream cannot be obtained
     */
    @Override
    public OutputStream getOutputStream() throws IOException {
        if (xorOutputStream == null) {
            xorOutputStream = new XorOutputStream(baseSocket.getOutputStream(), PATTERN);
        }
        return xorOutputStream;
    }

    @Override
    public void bind(SocketAddress bindpoint) throws IOException {
        baseSocket.bind(bindpoint);
    }

    @Override
    public synchronized void close() throws IOException {
        baseSocket.close();
    }

    @Override
    public void connect(SocketAddress endpoint) throws IOException {
        baseSocket.connect(endpoint);
    }

    @Override
    public void connect(SocketAddress endpoint, int timeout) throws IOException {
        baseSocket.connect(endpoint, timeout);
    }

    /** Returns the wrapped socket's channel; I/O through it bypasses the XOR wrappers. */
    @Override
    public SocketChannel getChannel() {
        return baseSocket.getChannel();
    }

    @Override
    public InetAddress getInetAddress() {
        return baseSocket.getInetAddress();
    }

    @Override
    public boolean getKeepAlive() throws SocketException {
        return baseSocket.getKeepAlive();
    }

    @Override
    public InetAddress getLocalAddress() {
        return baseSocket.getLocalAddress();
    }

    @Override
    public int getLocalPort() {
        return baseSocket.getLocalPort();
    }

    @Override
    public SocketAddress getLocalSocketAddress() {
        return baseSocket.getLocalSocketAddress();
    }

    @Override
    public boolean getOOBInline() throws SocketException {
        return baseSocket.getOOBInline();
    }

    @Override
    public int getPort() {
        return baseSocket.getPort();
    }

    @Override
    public synchronized int getReceiveBufferSize() throws SocketException {
        return baseSocket.getReceiveBufferSize();
    }

    @Override
    public SocketAddress getRemoteSocketAddress() {
        return baseSocket.getRemoteSocketAddress();
    }

    @Override
    public boolean getReuseAddress() throws SocketException {
        return baseSocket.getReuseAddress();
    }

    @Override
    public synchronized int getSendBufferSize() throws SocketException {
        return baseSocket.getSendBufferSize();
    }

    @Override
    public int getSoLinger() throws SocketException {
        return baseSocket.getSoLinger();
    }

    @Override
    public synchronized int getSoTimeout() throws SocketException {
        return baseSocket.getSoTimeout();
    }

    @Override
    public boolean getTcpNoDelay() throws SocketException {
        return baseSocket.getTcpNoDelay();
    }

    @Override
    public int getTrafficClass() throws SocketException {
        return baseSocket.getTrafficClass();
    }

    @Override
    public boolean isBound() {
        return baseSocket.isBound();
    }

    @Override
    public boolean isClosed() {
        return baseSocket.isClosed();
    }

    @Override
    public boolean isConnected() {
        return baseSocket.isConnected();
    }

    @Override
    public boolean isInputShutdown() {
        return baseSocket.isInputShutdown();
    }

    @Override
    public boolean isOutputShutdown() {
        return baseSocket.isOutputShutdown();
    }

    /** Sends urgent data on the wrapped socket directly; it is NOT XOR-transformed. */
    @Override
    public void sendUrgentData(int data) throws IOException {
        baseSocket.sendUrgentData(data);
    }

    @Override
    public void setKeepAlive(boolean on) throws SocketException {
        baseSocket.setKeepAlive(on);
    }

    @Override
    public void setOOBInline(boolean on) throws SocketException {
        baseSocket.setOOBInline(on);
    }

    @Override
    public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
        baseSocket.setPerformancePreferences(connectionTime, latency, bandwidth);
    }

    @Override
    public synchronized void setReceiveBufferSize(int size) throws SocketException {
        baseSocket.setReceiveBufferSize(size);
    }

    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        baseSocket.setReuseAddress(on);
    }

    @Override
    public synchronized void setSendBufferSize(int size) throws SocketException {
        baseSocket.setSendBufferSize(size);
    }

    @Override
    public void setSoLinger(boolean on, int linger) throws SocketException {
        baseSocket.setSoLinger(on, linger);
    }

    @Override
    public synchronized void setSoTimeout(int timeout) throws SocketException {
        baseSocket.setSoTimeout(timeout);
    }

    @Override
    public void setTcpNoDelay(boolean on) throws SocketException {
        baseSocket.setTcpNoDelay(on);
    }

    @Override
    public void setTrafficClass(int tc) throws SocketException {
        baseSocket.setTrafficClass(tc);
    }

    @Override
    public void shutdownInput() throws IOException {
        baseSocket.shutdownInput();
    }

    @Override
    public void shutdownOutput() throws IOException {
        baseSocket.shutdownOutput();
    }

    @Override
    public String toString() {
        return baseSocket.toString();
    }
}
/**
 * A {@link SocketFactory} that delegates socket creation to a wrapped factory and wraps
 * every socket it returns in a {@link MySocket}, so all created sockets get the XOR
 * stream wrappers.
 */
public class MySocketFactory extends SocketFactory {

    /** The factory that actually creates the sockets. */
    private final SocketFactory baseFactory;

    /**
     * Creates a wrapping factory.
     *
     * @param baseFactory the factory to delegate to; must not be {@code null}
     * @throws NullPointerException if {@code baseFactory} is {@code null}
     */
    public MySocketFactory(SocketFactory baseFactory) {
        this.baseFactory = java.util.Objects.requireNonNull(baseFactory, "baseFactory");
    }

    /**
     * Creates an unconnected socket via the delegate, wrapped in {@link MySocket}.
     * Previously this returned the delegate's socket unwrapped, silently skipping the
     * XOR layer.
     *
     * @return the wrapped, unconnected socket
     * @throws IOException if the delegate cannot create the socket
     */
    @Override
    public Socket createSocket() throws IOException {
        return new MySocket(baseFactory.createSocket());
    }

    /**
     * Two {@code MySocketFactory} instances are equal when their delegates are equal.
     * Unlike forwarding to {@code baseFactory.equals(obj)}, this is reflexive and
     * symmetric as required by {@link Object#equals(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof MySocketFactory)) {
            return false;
        }
        return baseFactory.equals(((MySocketFactory) obj).baseFactory);
    }

    /** Consistent with {@link #equals(Object)}: derived from the delegate's hash code. */
    @Override
    public int hashCode() {
        return baseFactory.hashCode();
    }

    @Override
    public String toString() {
        return baseFactory.toString();
    }

    /**
     * Creates a socket connected to {@code host}:{@code port}, wrapped in {@link MySocket}.
     *
     * @param host the remote host name
     * @param port the remote port
     * @return the wrapped, connected socket
     * @throws IOException if the delegate cannot create or connect the socket
     * @throws UnknownHostException if {@code host} cannot be resolved
     */
    @Override
    public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
        return new MySocket(baseFactory.createSocket(host, port));
    }

    /**
     * Creates a socket connected to {@code host}:{@code port} and bound locally to
     * {@code localHost}:{@code localPort}, wrapped in {@link MySocket}.
     *
     * @param host the remote host name
     * @param port the remote port
     * @param localHost the local address to bind to
     * @param localPort the local port to bind to
     * @return the wrapped, connected socket
     * @throws IOException if the delegate cannot create, bind or connect the socket
     * @throws UnknownHostException if {@code host} cannot be resolved
     */
    @Override
    public Socket createSocket(String host, int port, InetAddress localHost, int localPort)
            throws IOException, UnknownHostException {
        return new MySocket(baseFactory.createSocket(host, port, localHost, localPort));
    }

    /**
     * Creates a socket connected to {@code host}:{@code port}, wrapped in {@link MySocket}.
     *
     * @param host the remote address
     * @param port the remote port
     * @return the wrapped, connected socket
     * @throws IOException if the delegate cannot create or connect the socket
     */
    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        return new MySocket(baseFactory.createSocket(host, port));
    }

    /**
     * Creates a socket connected to {@code address}:{@code port} and bound locally to
     * {@code localAddress}:{@code localPort}, wrapped in {@link MySocket}.
     *
     * @param address the remote address
     * @param port the remote port
     * @param localAddress the local address to bind to
     * @param localPort the local port to bind to
     * @return the wrapped, connected socket
     * @throws IOException if the delegate cannot create, bind or connect the socket
     */
    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort)
            throws IOException {
        return new MySocket(baseFactory.createSocket(address, port, localAddress, localPort));
    }
}