我的解决方案适用于 Hibernate+Spring+MySQL 的非常常见的用例
与上述答案类似,我的解决方案基于 Richard Kennar 博士的方案。然而,由于 Hibernate 经常与 Spring 一起使用,我希望我的解决方案能够很好地配合 Spring 以及使用 Hibernate 的标准方式。因此,我的解决方案结合使用 ThreadLocal 和单例 bean 来实现这一结果。从技术上讲,SessionFactory 准备的每一条 SQL 语句都会调用该拦截器,但除非当前查询被专门设置为需要计算总行数,否则它会跳过所有逻辑,也不会初始化任何 ThreadLocal。
使用下面的类,您的 Spring 配置如下所示:
<!-- Interceptor bean. The commented-out fragment below shows the optional
     sessionFactoryBeanName property as an alternative ending for this tag. -->
<bean id="foundRowCalculator" class="my.hibernate.classes.MySQLCalcFoundRowsInterceptor" />
<!-- p:sessionFactoryBeanName="mySessionFactory"/ -->
<!-- Session factory; its entityInterceptor property references the interceptor bean declared above. -->
<bean id="mySessionFactory"
class="org.springframework.orm.hibernate3.annotation.AnnotationSessionFactoryBean"
p:dataSource-ref="dataSource"
p:packagesToScan="my.hibernate.classes"
p:entityInterceptor-ref="foundRowCalculator"/>
基本上,您必须声明拦截器 bean,然后在 SessionFactoryBean 的“entityInterceptor”属性中引用它。只有当您的 Spring 上下文中存在多个 SessionFactory,并且要引用的会话工厂不叫“sessionFactory”时,才需要设置“sessionFactoryBeanName”。之所以不能直接设置引用,是因为这会造成 bean 之间无法解析的循环依赖。
使用一个包装器 bean 来承载结果:
package my.hibernate.classes;
/**
 * Immutable holder for one page of query results plus the total row count.
 *
 * @param <T> element type of the page
 */
public class PagedResponse<T> {

    /**
     * Items of the requested page. The list reference is stored as given (no copy is made).
     * The type is fully qualified because this file declares no import for it.
     */
    public final java.util.List<T> items;

    /** Total count supplied by the caller alongside the page. */
    public final int total;

    /**
     * Creates a response wrapping the given page and total.
     *
     * @param items items of the page
     * @param total total count to report
     */
    public PagedResponse(java.util.List<T> items, int total) {
        this.items = items;
        this.total = total;
    }
}
然后是一个抽象的基础 DAO 类:在执行查询之前必须调用“setCalcFoundRows(true)”,并在查询之后调用“reset()”(放在 finally 块中,以确保它一定会被调用):
package my.hibernate.classes;
import org.hibernate.Criteria;
import org.hibernate.Query;
import org.springframework.beans.factory.annotation.Autowired;
public abstract class BaseDAO {
@Autowired
private MySQLCalcFoundRowsInterceptor rowCounter;
public <T> PagedResponse<T> getPagedResponse(Criteria crit, int firstResult, int maxResults) {
rowCounter.setCalcFoundRows(true);
try {
@SuppressWarnings("unchecked")
return new PagedResponse<T>(
crit.
setFirstResult(firstResult).
setMaxResults(maxResults).
list(),
rowCounter.getFoundRows());
} finally {
rowCounter.reset();
}
}
public <T> PagedResponse<T> getPagedResponse(Query query, int firstResult, int maxResults) {
rowCounter.setCalcFoundRows(true);
try {
@SuppressWarnings("unchecked")
return new PagedResponse<T>(
query.
setFirstResult(firstResult).
setMaxResults(maxResults).
list(),
rowCounter.getFoundRows());
} finally {
rowCounter.reset();
}
}
}
然后是一个名为 MyEntity 的 @Entity 的具体 DAO 类示例,它具有字符串属性“prop”:
package my.hibernate.classes;
import org.hibernate.SessionFactory;
import org.hibernate.criterion.Restrictions;
import org.springframework.beans.factory.annotation.Autowired;
/**
 * Example concrete DAO for the {@code MyEntity} entity, built on the paging
 * helper inherited from {@link BaseDAO}.
 */
public class MyEntityDAO extends BaseDAO {

    @Autowired
    private SessionFactory sessionFactory;

    /**
     * Returns one page of {@code MyEntity} rows whose {@code prop} property equals the given value.
     *
     * @param propVal     value the {@code prop} property must equal
     * @param firstResult index of the first row to return
     * @param maxResults  maximum number of rows to return
     * @return the page produced by the inherited {@code getPagedResponse} helper
     */
    public PagedResponse<MyEntity> getPagedEntitiesWithPropertyValue(String propVal, int firstResult, int maxResults) {
        return this.<MyEntity>getPagedResponse(
                sessionFactory.getCurrentSession()
                        .createCriteria(MyEntity.class)
                        .add(Restrictions.eq("prop", propVal)),
                firstResult,
                maxResults);
    }
}
最后是完成所有工作的拦截器类:
package my.hibernate.classes;
import java.io.IOException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import org.hibernate.EmptyInterceptor;
import org.hibernate.HibernateException;
import org.hibernate.SessionFactory;
import org.hibernate.Transaction;
import org.hibernate.jdbc.Work;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
/**
 * Hibernate interceptor that obtains the total row count of a paged MySQL
 * query through {@code SQL_CALC_FOUND_ROWS} / {@code FOUND_ROWS()}.
 *
 * <p>All counting state is kept in {@link ThreadLocal}s. For the calling thread
 * the flow is:
 * <ol>
 *   <li>{@link #setCalcFoundRows(boolean) setCalcFoundRows(true)} arms the interceptor;</li>
 *   <li>the first statement prepared afterwards gets {@code SQL_CALC_FOUND_ROWS}
 *       inserted right after its first {@code "select "};</li>
 *   <li>{@code select FOUND_ROWS()} is executed either just before the second
 *       prepared statement or on demand from {@link #getFoundRows()};</li>
 *   <li>{@link #reset()} clears the state again.</li>
 * </ol>
 * While the interceptor is not armed, {@link #onPrepareStatement(String)}
 * returns the SQL untouched.
 */
public class MySQLCalcFoundRowsInterceptor extends EmptyInterceptor implements BeanFactoryAware {

    private static final long serialVersionUID = 2745492452467374139L;

    //
    // Private statics
    //
    private final static String SELECT_PREFIX = "select ";
    private final static String CALC_FOUND_ROWS_HINT = "SQL_CALC_FOUND_ROWS ";
    private final static String SELECT_FOUND_ROWS = "select FOUND_ROWS()";

    //
    // Private members
    //
    /** Resolved lazily by {@link #init()}. */
    private SessionFactory sessionFactory;
    private BeanFactory beanFactory;
    /** Optional bean name used by {@link #init()} to look up the session factory. */
    private String sessionFactoryBeanName;

    /** {@code TRUE} while the current thread has asked for a row count. */
    private final ThreadLocal<Boolean> mCalcFoundRows = new ThreadLocal<Boolean>();

    /** Counts statements prepared on this thread while armed (only 0, 1 and 2 are ever stored). */
    private final ThreadLocal<Integer> mSQLStatementsPrepared = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return Integer.valueOf(0);
        }
    };

    /** Captured {@code FOUND_ROWS()} value, or unset while nothing has been captured. */
    private final ThreadLocal<Integer> mFoundRows = new ThreadLocal<Integer>();

    /**
     * Resolves the session factory from the bean factory on first use: by the
     * configured bean name if one was set, otherwise by the name
     * {@code "sessionFactory"}, and finally by type.
     */
    private void init() {
        if (sessionFactory != null) {
            return;
        }
        if (sessionFactoryBeanName != null) {
            sessionFactory = beanFactory.getBean(sessionFactoryBeanName, SessionFactory.class);
            return;
        }
        try {
            sessionFactory = beanFactory.getBean("sessionFactory", SessionFactory.class);
        } catch (BeansException ignored) {
            // No usable bean under the conventional name; fall back to the by-type lookup below.
        }
        if (sessionFactory == null) {
            sessionFactory = beanFactory.getBean(SessionFactory.class);
        }
    }

    /**
     * Rewrites or passes through the SQL depending on how many statements have
     * been prepared on this thread since counting was enabled.
     *
     * @param sql the SQL about to be prepared
     * @return the SQL with the hint inserted for the first statement, otherwise {@code sql} unchanged
     * @throws HibernateException if the first statement does not contain {@code "select "}
     */
    @Override
    public String onPrepareStatement(String sql) {
        if (mCalcFoundRows.get() == null || !mCalcFoundRows.get().booleanValue()) {
            return sql;
        }
        switch (mSQLStatementsPrepared.get()) {
        case 0: {
            mSQLStatementsPrepared.set(mSQLStatementsPrepared.get() + 1);
            // First time, prefix CALC_FOUND_ROWS_HINT
            StringBuilder builder = new StringBuilder(sql);
            int indexOf = builder.indexOf(SELECT_PREFIX);
            if (indexOf == -1) {
                throw new HibernateException("First SQL statement did not contain '" + SELECT_PREFIX + "'");
            }
            builder.insert(indexOf + SELECT_PREFIX.length(), CALC_FOUND_ROWS_HINT);
            return builder.toString();
        }
        case 1: {
            mSQLStatementsPrepared.set(mSQLStatementsPrepared.get() + 1);
            // Before any secondary selects, capture FOUND_ROWS. If no secondary
            // selects are ever executed, getFoundRows() will capture FOUND_ROWS
            // just-in-time when called directly.
            captureFoundRows();
            return sql;
        }
        default:
            // Pass-through untouched
            return sql;
        }
    }

    /**
     * Clears all per-thread counting state. Does nothing when counting is not
     * enabled for the current thread.
     */
    public void reset() {
        if (mCalcFoundRows.get() != null && mCalcFoundRows.get().booleanValue()) {
            mSQLStatementsPrepared.remove();
            mFoundRows.remove();
            mCalcFoundRows.remove();
        }
    }

    /** Resets the per-thread state when a transaction completes. */
    @Override
    public void afterTransactionCompletion(Transaction tx) {
        reset();
    }

    /**
     * Enables or disables row counting for the current thread.
     *
     * @param calc {@code true} to arm the interceptor; {@code false} is equivalent to {@link #reset()}
     */
    public void setCalcFoundRows(boolean calc) {
        if (calc) {
            mCalcFoundRows.set(Boolean.TRUE);
        } else {
            reset();
        }
    }

    /**
     * Returns the found-row count, capturing it now if it has not been captured yet.
     *
     * @return the value read from {@code select FOUND_ROWS()}
     * @throws IllegalStateException if counting was not enabled on this thread
     * @throws HibernateException    if no statement has been prepared since counting was enabled
     */
    public int getFoundRows() {
        if (mCalcFoundRows.get() == null || !mCalcFoundRows.get().booleanValue()) {
            throw new IllegalStateException("Attempted to getFoundRows without first calling 'setCalcFoundRows'");
        }
        if (mFoundRows.get() == null) {
            captureFoundRows();
        }
        return mFoundRows.get();
    }

    //
    // Private methods
    //

    /**
     * Executes {@code select FOUND_ROWS()} on the current session's connection
     * and stores the result for this thread.
     */
    private void captureFoundRows() {
        init();
        // Sanity checks
        if (mFoundRows.get() != null) {
            throw new HibernateException("'" + SELECT_FOUND_ROWS + "' called more than once");
        }
        if (mSQLStatementsPrepared.get() < 1) {
            throw new HibernateException("'" + SELECT_FOUND_ROWS + "' called before '" + SELECT_PREFIX + CALC_FOUND_ROWS_HINT + "'");
        }
        // Fetch the total number of rows
        sessionFactory.getCurrentSession().doWork(new Work() {
            @Override
            public void execute(Connection connection) throws SQLException {
                final Statement stmt = connection.createStatement();
                try {
                    final ResultSet rs = stmt.executeQuery(SELECT_FOUND_ROWS);
                    try {
                        mFoundRows.set(Integer.valueOf(rs.next() ? rs.getInt(1) : 0));
                    } finally {
                        rs.close();
                    }
                } finally {
                    // The nested try/finally guarantees the statement is closed
                    // even when closing the result set throws.
                    try {
                        stmt.close();
                    } catch (RuntimeException ignored) {
                        // Best-effort cleanup: a driver runtime failure on close is deliberately not propagated.
                    }
                }
            }
        });
    }

    /**
     * Sets the bean name used to look up the session factory.
     *
     * @param sessionFactoryBeanName name of the {@code SessionFactory} bean
     */
    public void setSessionFactoryBeanName(String sessionFactoryBeanName) {
        this.sessionFactoryBeanName = sessionFactoryBeanName;
    }

    /** Stores the bean factory that {@link #init()} later uses for the lookup. */
    @Override
    public void setBeanFactory(BeanFactory arg0) throws BeansException {
        this.beanFactory = arg0;
    }
}