经过大量研究,我终于找到了实现我想要的方法。它的要点是我在对象上下文中使用事件处理程序拦截物化实体,然后将我的自定义集合类注入到我可以找到的每个集合属性中(通过反射)。
最重要的部分是拦截“DbCollectionEntry”,该类负责加载相关的集合属性。通过在实体和 DbCollectionEntry 之间摆动自己,我可以完全控制何时以及如何加载的内容。唯一的缺点是这个 DbCollectionEntry 类几乎没有公共成员,这需要我使用反射来操作它。
这是我的自定义集合类,它实现了 ICollection 并包含对适当 DbCollectionEntry 的引用:
public class FilteredCollection <TEntity> : ICollection<TEntity> where TEntity : Entity
{
private readonly DbCollectionEntry _dbCollectionEntry;
private readonly Func<TEntity, Boolean> _compiledFilter;
private readonly Expression<Func<TEntity, Boolean>> _filter;
private ICollection<TEntity> _collection;
private int? _cachedCount;
public FilteredCollection(ICollection<TEntity> collection, DbCollectionEntry dbCollectionEntry)
{
_filter = entity => !entity.Deleted;
_dbCollectionEntry = dbCollectionEntry;
_compiledFilter = _filter.Compile();
_collection = collection != null ? collection.Where(_compiledFilter).ToList() : null;
}
private ICollection<TEntity> Entities
{
get
{
if (_dbCollectionEntry.IsLoaded == false && _collection == null)
{
IQueryable<TEntity> query = _dbCollectionEntry.Query().Cast<TEntity>().Where(_filter);
_dbCollectionEntry.CurrentValue = this;
_collection = query.ToList();
object internalCollectionEntry =
_dbCollectionEntry.GetType()
.GetField("_internalCollectionEntry", BindingFlags.NonPublic | BindingFlags.Instance)
.GetValue(_dbCollectionEntry);
object relatedEnd =
internalCollectionEntry.GetType()
.BaseType.GetField("_relatedEnd", BindingFlags.NonPublic | BindingFlags.Instance)
.GetValue(internalCollectionEntry);
relatedEnd.GetType()
.GetField("_isLoaded", BindingFlags.NonPublic | BindingFlags.Instance)
.SetValue(relatedEnd, true);
}
return _collection;
}
}
#region ICollection<T> Members
void ICollection<TEntity>.Add(TEntity item)
{
if(_compiledFilter(item))
Entities.Add(item);
}
void ICollection<TEntity>.Clear()
{
Entities.Clear();
}
Boolean ICollection<TEntity>.Contains(TEntity item)
{
return Entities.Contains(item);
}
void ICollection<TEntity>.CopyTo(TEntity[] array, Int32 arrayIndex)
{
Entities.CopyTo(array, arrayIndex);
}
Int32 ICollection<TEntity>.Count
{
get
{
if (_dbCollectionEntry.IsLoaded)
return _collection.Count;
return _dbCollectionEntry.Query().Cast<TEntity>().Count(_filter);
}
}
Boolean ICollection<TEntity>.IsReadOnly
{
get
{
return Entities.IsReadOnly;
}
}
Boolean ICollection<TEntity>.Remove(TEntity item)
{
return Entities.Remove(item);
}
#endregion
#region IEnumerable<T> Members
IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
return Entities.GetEnumerator();
}
#endregion
#region IEnumerable Members
IEnumerator IEnumerable.GetEnumerator()
{
return ( ( this as IEnumerable<TEntity> ).GetEnumerator() );
}
#endregion
}
如果您浏览它,您会发现最重要的部分是“实体”属性,它会延迟加载实际值。在 FilteredCollection 的构造函数中,我为已经急切加载集合的场景传递了一个可选的 ICollection。
当然,我们仍然需要配置实体框架,以便我们的 FilteredCollection 在任何有集合属性的地方都可以使用。这可以通过挂钩实体框架底层 ObjectContext 的 ObjectMaterialized 事件来实现:
(this as IObjectContextAdapter).ObjectContext.ObjectMaterialized +=
delegate(Object sender, ObjectMaterializedEventArgs e)
{
if (e.Entity is Entity)
{
var entityType = e.Entity.GetType();
IEnumerable<PropertyInfo> collectionProperties;
if (!CollectionPropertiesPerType.TryGetValue(entityType, out collectionProperties))
{
CollectionPropertiesPerType[entityType] = (collectionProperties = entityType.GetProperties()
.Where(p => p.PropertyType.IsGenericType && typeof(ICollection<>) == p.PropertyType.GetGenericTypeDefinition()));
}
foreach (var collectionProperty in collectionProperties)
{
var collectionType = typeof(FilteredCollection<>).MakeGenericType(collectionProperty.PropertyType.GetGenericArguments());
DbCollectionEntry dbCollectionEntry = Entry(e.Entity).Collection(collectionProperty.Name);
dbCollectionEntry.CurrentValue = Activator.CreateInstance(collectionType, new[] { dbCollectionEntry.CurrentValue, dbCollectionEntry });
}
}
};
这一切看起来相当复杂,但它本质上是扫描物化类型的集合属性并将值更改为过滤集合。它还将 DbCollectionEntry 传递给过滤后的集合,以便发挥它的魔力。
这涵盖了整个“加载实体”部分。到目前为止唯一的缺点是急切加载的集合属性仍将包含已删除的实体,但它们在 FilterCollection 类的“添加”方法中被过滤掉了。这是一个可以接受的缺点,尽管我还没有对这如何影响 SaveChanges() 方法进行一些测试。
当然,这仍然留下一个问题:查询没有自动过滤。如果您想获取在过去一周进行过锻炼的健身房成员,您希望自动排除已删除的锻炼。
这是通过 ExpressionVisitor 实现的,该表达式自动将 '.Where(e => !e.Deleted)' 过滤器应用于它可以在给定表达式中找到的每个 IQueryable。
这是代码:
public class DeletedFilterInterceptor: ExpressionVisitor
{
public Expression<Func<Entity, bool>> Filter { get; set; }
public DeletedFilterInterceptor()
{
Filter = entity => !entity.Deleted;
}
protected override Expression VisitMember(MemberExpression ex)
{
return !ex.Type.IsGenericType ? base.VisitMember(ex) : CreateWhereExpression(Filter, ex) ?? base.VisitMember(ex);
}
private Expression CreateWhereExpression(Expression<Func<Entity, bool>> filter, Expression ex)
{
var type = ex.Type;//.GetGenericArguments().First();
var test = CreateExpression(filter, type);
if (test == null)
return null;
var listType = typeof(IQueryable<>).MakeGenericType(type);
return Expression.Convert(Expression.Call(typeof(Enumerable), "Where", new Type[] { type }, (Expression)ex, test), listType);
}
private LambdaExpression CreateExpression(Expression<Func<Entity, bool>> condition, Type type)
{
var lambda = (LambdaExpression) condition;
if (!typeof(Entity).IsAssignableFrom(type))
return null;
var newParams = new[] { Expression.Parameter(type, "entity") };
var paramMap = lambda.Parameters.Select((original, i) => new { original, replacement = newParams[i] }).ToDictionary(p => p.original, p => p.replacement);
var fixedBody = ParameterRebinder.ReplaceParameters(paramMap, lambda.Body);
lambda = Expression.Lambda(fixedBody, newParams);
return lambda;
}
}
public class ParameterRebinder : ExpressionVisitor
{
private readonly Dictionary<ParameterExpression, ParameterExpression> _map;
public ParameterRebinder(Dictionary<ParameterExpression, ParameterExpression> map)
{
_map = map ?? new Dictionary<ParameterExpression, ParameterExpression>();
}
public static Expression ReplaceParameters(Dictionary<ParameterExpression, ParameterExpression> map, Expression exp)
{
return new ParameterRebinder(map).Visit(exp);
}
protected override Expression VisitParameter(ParameterExpression node)
{
ParameterExpression replacement;
if (_map.TryGetValue(node, out replacement))
node = replacement;
return base.VisitParameter(node);
}
}
我的时间有点短,所以我稍后会回到这篇文章,提供更多细节,但它的要点已经写下来,适合那些渴望尝试一切的人;我在这里发布了完整的测试应用程序:https ://github.com/amoerie/TestingGround
但是,可能仍然存在一些错误,因为这是一项正在进行的工作。不过,这个概念性的想法是合理的,我希望它能够在我整齐地重构所有内容并找到时间为此编写一些测试后很快就能完全发挥作用。