4

我被要求修复一个位于两个应用程序之间的 Servlet,它的作用是在 SAML v2.0 和 SAML v1.1 之间相互转换 SAML 授权请求与响应。具体流程如下:

  • 从一个应用程序接收 HTTP SAML v2.0 授权请求
  • 将请求转换为 SAML v1.1
  • 将请求发送到第二个应用
  • 从第二个应用接收 SAML v1.1 响应
  • 将响应转换为 SAML v2.0
  • 将响应发送回第一个应用程序

不用管 SAML 的部分,问题出在 HTTP 的部分。代码能完成它的工作,但在负载下性能严重下降。我通过测试发现,尽管代码使用了 Apache HttpComponents 中的 ThreadSafeClientConnManager,到达 servlet 的每个请求仍然是以单线程方式处理的。更准确地说,一旦代码执行到 HttpClient.execute() 方法,第一个建立连接的线程就会先跑完整个剩余流程,其他线程才会开始工作。例如:

  • 15 个请求同时命中 servlet
  • servlet 产生 15 个线程来服务请求
  • 所有 15 个线程都检索各自的请求数据
  • 所有 15 个线程将各自的数据从 SAML v2.0 转换为 SAML v1.1
  • 线程 1 调用 HttpClient.execute()
    • 线程 1 将请求发送到第二个应用程序
    • 线程 1 接收到来自第二个应用程序的响应
    • 线程 1 解码响应并将其从 SAML v1.1 转换为 SAML v2.0
    • 线程 1 将响应发送回第一个应用程序
  • 线程 2 调用 HttpClient.execute()
  • ... 等等 ...

代码如下。据我所见,所有必要的部分都已具备。有没有人能看出有什么错误或遗漏,会导致这个 servlet 无法同时处理多个请求?

public class MappingServlet extends HttpServlet {

private HttpClient client;
private String pdp_url;

public void init() throws ServletException {
    org.opensaml.Configuration.init();
    pdp_url = getInitParameter("pdp_url");

    ThreadSafeClientConnManager cm = new ThreadSafeClientConnManager();
    HttpRoute route = new HttpRoute(new HttpHost(pdp_url));
    cm.setDefaultMaxPerRoute(100);
    cm.setMaxForRoute(route, 100);
    cm.setMaxTotal(100);
    client = new DefaultHttpClient(cm);
}

protected void doPost(HttpServletRequest request, HttpServletResponse response)
    throws ServletException, IOException {

    long threadId = Thread.currentThread().getId();
    log.debug("[THREAD " + threadId + "] client request received");

    // Get the input entity (SAML2)
    InputStream in = null;
    byte[] query11 = null;
    try {
        in = request.getInputStream();
        query11 = Saml2Requester.convert(in);
        log.debug("[THREAD " + threadId + "] client request SAML11:\n" + query11);
    } catch (IOException ex) {
        log.error("[THREAD " + threadId + "]\n", ex);
        return;
    } finally {
        if (in != null) {
            try {
                in.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]\n", ioe);
            }
        }
    }

    // Proxy the request to the PDP
    HttpPost httpPost = new HttpPost(pdp_url);
    ByteArrayEntity entity = new ByteArrayEntity(query11);
    httpPost.setEntity(entity);
    HttpResponse httpResponse = null;
    try {
        httpResponse = client.execute(httpPost);
    } catch (IOException ioe) {
        log.error("[THREAD " + threadId + "]\n", ioe);
        httpPost.abort();
        return;
    }

    int sc = httpResponse.getStatusLine().getStatusCode();
    if (sc != HttpStatus.SC_OK) {
        log.error("[THREAD " + threadId + "] Bad response from PDP: " + sc);
        httpPost.abort();
        return;
    }

    // Get the response back from the PDP
    InputStream in2 = null;
    byte[] resp = null;
    try {
        HttpEntity entity2 = httpResponse.getEntity();
        in2 = entity2.getContent();
        resp = Saml2Requester.consumeStream(in2);
        EntityUtils.consumeStream(in2);
        log.debug("[THREAD " + threadId + "] client response received, SAML11: " + resp);
    } catch (IOException ex) {
        log.error("[THREAD " + threadId + "]", ex);
        httpPost.abort();
        return;
    } finally {
        if (in2 != null) {
            try {
                in2.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]", ioe);
            }
        }
    }

    // Convert the response from SAML1.1 to SAML2 and send back
    ByteArrayInputStream respStream = null;
    byte[] resp2 = null;
    try {
        respStream = new ByteArrayInputStream(resp);
        resp2 = Saml2Responder.convert(respStream);
    } finally {
        if (respStream != null) {
            try {
                respStream.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]", ioe);
            }
        }
    }
    log.debug("[THREAD " + threadId + "] client response SAML2: " + resp2);

    OutputStream os2 = null;
    try {
        os2 = response.getOutputStream();
        os2.write(resp2.getBytes());
        log.debug("[THREAD " + threadId + "] client response forwarded");
    } catch (IOException ex) {
        log.error("[THREAD " + threadId + "]\n", ex);
        return;
    } finally {
        if (os2 != null) {
            try {
                os2.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]\n", ioe);
            }
        }
    }
}

public void destroy() {
    client.getConnectionManager().shutdown();
    super.destroy();
}

}

提前致谢!

4

1 回答 1

4

HttpClient.execute() 要等到被调用的服务器发回全部 HTTP 响应头之后才会返回。您的代码本身工作正常,我认为被调用的服务才是真正的瓶颈。我基于您的代码片段写了一个简单的概念验证:

import java.io.IOException;
import java.net.URI;

import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.conn.routing.HttpRoute;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.conn.tsccm.ThreadSafeClientConnManager;
import org.apache.http.util.EntityUtils;

public class MyHttpClient {

    private static final String TARGET_URL = "http://localhost:8080/WaitServlet";

    private final DefaultHttpClient client;

    /**
     * Creates a client backed by a thread-safe connection pool that allows up to 100 concurrent
     * connections, both per route and in total.
     */
    public MyHttpClient() {
        final ThreadSafeClientConnManager cm = new ThreadSafeClientConnManager();
        // HttpHost expects a host name (plus port and scheme), not a full URL. Derive those parts
        // from TARGET_URL; otherwise the per-route limit is keyed on a route that never occurs.
        final URI target = URI.create(TARGET_URL);
        final HttpRoute route = new HttpRoute(
                new HttpHost(target.getHost(), target.getPort(), target.getScheme()));
        cm.setDefaultMaxPerRoute(100);
        cm.setMaxForRoute(route, 100);
        cm.setMaxTotal(100);
        client = new DefaultHttpClient(cm);
    }

    /**
     * Sends one POST to {@code TARGET_URL} and prints the status line.
     *
     * <p>The connection is always returned to the pool. On a non-200 status or an I/O error the
     * request is aborted; on a 200 the response body is consumed.
     */
    public void doPost() {
        final HttpPost httpPost = new HttpPost(TARGET_URL);

        final HttpResponse httpResponse;
        try {
            httpResponse = client.execute(httpPost);
        } catch (final IOException ioe) {
            ioe.printStackTrace();
            httpPost.abort();
            return;
        }

        final StatusLine statusLine = httpResponse.getStatusLine();
        System.out.println("status: " + statusLine);
        if (statusLine.getStatusCode() != HttpStatus.SC_OK) {
            httpPost.abort();
            return;
        }

        // Consume the body so the connection goes back to the pool. Without this, every
        // successful call leaks a pooled connection, and execute() eventually blocks waiting for one.
        try {
            EntityUtils.consume(httpResponse.getEntity());
        } catch (final IOException ioe) {
            ioe.printStackTrace();
            httpPost.abort();
        }
    }
}

和一个测试:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Test;

public class HttpClientTest {

    /** Number of POSTs fired concurrently at WaitServlet. */
    private static final int REQUESTS = 8;

    /** Upper bound for all requests to finish if they are served concurrently. */
    private static final long TIMEOUT_SECONDS = 150;

    /**
     * Fires {@link #REQUESTS} POSTs in parallel through one shared {@link MyHttpClient}.
     *
     * <p>WaitServlet holds each request for 30 s. Served concurrently, all requests finish in
     * roughly 30 s; served one after another, they would need about 8 × 30 = 240 s and overrun
     * the timeout. The test therefore fails instead of silently passing in that case.
     */
    @Test
    public void test2() throws Exception {
        final ExecutorService executorService = Executors.newFixedThreadPool(16);
        final MyHttpClient myHttpClient = new MyHttpClient();

        for (int i = 0; i < REQUESTS; i++) {
            executorService.execute(new Runnable() {
                @Override
                public void run() {
                    myHttpClient.doPost();
                }
            });
        }

        executorService.shutdown();
        if (!executorService.awaitTermination(TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
            executorService.shutdownNow();
            throw new AssertionError(
                    REQUESTS + " requests did not complete within " + TIMEOUT_SECONDS + " s");
        }
    }
}

最后是被调用的 WaitServlet:

import java.io.IOException;
import java.io.PrintWriter;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class WaitServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;

    /** How long each request is held before it is answered, simulating a slow backend. */
    private static final long WAIT_MILLIS = 30 * 1000L;

    /** Handles GET exactly like POST. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        doPost(req, resp);
    }

    /** Sleeps for {@link #WAIT_MILLIS} ms, then answers with the text "wait end". */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        try {
            Thread.sleep(WAIT_MILLIS);
        } catch (InterruptedException e) {
            // Restore the interrupt flag rather than swallowing it, so code further up the
            // stack can still observe the interruption.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        final PrintWriter writer = resp.getWriter();
        writer.println("wait end");
    }
}
于 2011-09-29T20:56:38.127 回答