I have the following code:

IEnumerable<KeyValuePair<T, double>> items =
    sequence.Select(item => new KeyValuePair<T, double>(item, weight(item)));
if (items.Any(pair => pair.Value<0))
    throw new ArgumentException("Item weights cannot be less than zero.");

double sum = items.Sum(pair => pair.Value);
foreach (KeyValuePair<T, double> pair in items) {...}

where weight is a Func<T, double>.

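The problem is that I want weight to be executed as few times as possible, meaning at most once per item. I could achieve this by saving the results to an array, but if any weight is negative I don't want to continue.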
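To make the problem concrete, here is a small counting sketch. It assumes a hypothetical string-length weight and a hypothetical class name (DeferredWeightDemo), and simply runs the code from the question. Because Select is deferred, Any, Sum and the foreach each re-run the projection, so with no negative weights every item is weighed three times.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical demo: counts how often the question's pipeline invokes weight.
public static class DeferredWeightDemo
{
    public static int CountWeightCalls(IEnumerable<string> sequence)
    {
        int calls = 0;
        // Hypothetical weight function: the string length, with a call counter.
        Func<string, double> weight = s => { calls++; return s.Length; };

        IEnumerable<KeyValuePair<string, double>> items =
            sequence.Select(item => new KeyValuePair<string, double>(item, weight(item)));

        // Each of the three operations below enumerates items again,
        // which re-runs the Select projection and therefore weight.
        if (items.Any(pair => pair.Value < 0))
            throw new ArgumentException("Item weights cannot be less than zero.");

        double sum = items.Sum(pair => pair.Value);
        foreach (KeyValuePair<string, double> pair in items)
        {
            _ = pair.Value / sum; // stand-in for the real per-item work
        }
        return calls;
    }
}
```

For three items, CountWeightCalls returns 9 (3 items × 3 enumerations).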

Is there an easy way to do this within LINQ?

6 Answers

Sure, this is entirely doable:

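using System;
using System.Collections.Generic;

// Extension methods must live in a non-generic static class.
public static class FuncExtensions
{
    // Wraps f so that any negative (or NaN) result throws.
    // Note: "double" cannot be used as a type parameter name, so only A is generic.
    public static Func<A, double> ThrowIfNegative<A>(this Func<A, double> f)
    {
        return a =>
        {
            double r = f(a);
            // If r is NaN, !(r >= 0.0) is true, so NaN throws as well.
            if (!(r >= 0.0))
                throw new ArgumentException("Item weights cannot be less than zero.");
            return r;
        };
    }

    // Wraps f so that it is evaluated at most once per distinct argument.
    public static Func<A, R> Memoize<A, R>(this Func<A, R> f)
    {
        var d = new Dictionary<A, R>();
        return a =>
        {
            R r;
            if (!d.TryGetValue(a, out r))
            {
                r = f(a);
                d.Add(a, r);
            }
            return r;
        };
    }
}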

Now...

Func<T, double> weight = whatever;
weight = weight.ThrowIfNegative().Memoize();

and you're done.
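A self-contained usage sketch, assuming C# 7 or later. It repeats the two helpers from this answer (with ThrowIfNegative taking a single type parameter) so that it compiles on its own; the class names WeightFuncHelpers and MemoizedWeightDemo and the string-length weight are hypothetical. Even though items is still a deferred query that gets enumerated twice, the wrapped weight only runs once per distinct item.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class WeightFuncHelpers
{
    // Same idea as ThrowIfNegative above: negative or NaN results throw.
    public static Func<A, double> ThrowIfNegative<A>(this Func<A, double> f) =>
        a =>
        {
            double r = f(a);
            if (!(r >= 0.0))
                throw new ArgumentException("Item weights cannot be less than zero.");
            return r;
        };

    // Same idea as Memoize above: cache each result keyed by its argument.
    public static Func<A, R> Memoize<A, R>(this Func<A, R> f)
    {
        var cache = new Dictionary<A, R>();
        return a =>
        {
            if (!cache.TryGetValue(a, out R r))
            {
                r = f(a);
                cache.Add(a, r);
            }
            return r;
        };
    }
}

public static class MemoizedWeightDemo
{
    // Runs the question's pipeline with a wrapped weight; returns (calls, sum).
    public static (int Calls, double Sum) Run(IEnumerable<string> sequence)
    {
        int calls = 0;
        Func<string, double> weight = s => { calls++; return s.Length; };
        weight = weight.ThrowIfNegative().Memoize();

        var items = sequence.Select(item => new KeyValuePair<string, double>(item, weight(item)));
        double sum = items.Sum(pair => pair.Value); // first pass: weight runs per item
        foreach (var pair in items) { }             // second pass: answered from the cache
        return (calls, sum);
    }
}
```

One design consequence worth noting: a call that throws is not cached, but since the exception aborts the whole pipeline, that does not cost extra weight calls in practice.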

answered 2012-04-20T23:05:28.653

One approach is to move the exception into the weight function, or at least simulate doing so, like this:

Func<T, double> weightWithCheck = i =>
    {
        double result = weight(i);
        if (result < 0)
        {
            throw new ArgumentException("Item weights cannot be less than zero.");
        }
        return result;
    };

IEnumerable<KeyValuePair<T, double>> items =
    sequence.Select(item => new KeyValuePair<T, double>(item, weightWithCheck(item)));

double sum = items.Sum(pair => pair.Value);

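At this point, if there is a negative weight, you will have gotten the exception. You do have to enumerate items before you can be sure of getting it, but once you get it, weight will not be called again.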
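A counting sketch of this behaviour, assuming C# 7 or later. The class name DeferredCheckDemo and the identity weight (each int is its own weight) are hypothetical. Building items calls weight zero times; the check only fires when Sum enumerates, and enumeration stops at the first negative item.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class DeferredCheckDemo
{
    // Returns how many times weight ran after building the query and after Sum,
    // and whether Sum threw.
    public static (int CallsAfterBuild, int CallsAfterSum, bool Threw) Run(IEnumerable<int> sequence)
    {
        int calls = 0;
        // Hypothetical weight: the item's own value, with a call counter.
        Func<int, double> weight = i => { calls++; return i; };

        Func<int, double> weightWithCheck = i =>
        {
            double result = weight(i);
            if (result < 0)
                throw new ArgumentException("Item weights cannot be less than zero.");
            return result;
        };

        IEnumerable<KeyValuePair<int, double>> items =
            sequence.Select(item => new KeyValuePair<int, double>(item, weightWithCheck(item)));
        int callsAfterBuild = calls; // Select is deferred: nothing has run yet

        try
        {
            items.Sum(pair => pair.Value); // enumeration (and weighing) starts here
            return (callsAfterBuild, calls, false);
        }
        catch (ArgumentException)
        {
            return (callsAfterBuild, calls, true);
        }
    }
}
```

For { 1, -2, 3 } this reports (0, 2, true): the third item is never weighed. Note that `result < 0` lets NaN through, unlike the `!(r >= 0.0)` test in the previous answer.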

answered 2012-04-20T23:06:27.767

Both answers are good (moving where the exception is thrown, and memoizing the function).

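But your real problem is that your LINQ expression is re-evaluated every time you use it, unless you force it to evaluate and store the results in a List (or similar). Just change this: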

sequence.Select(item => new KeyValuePair<T, double>(item, weight(item)));

to this:

sequence.Select(item => new KeyValuePair<T, double>(item, weight(item))).ToList();
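A small counting sketch contrasting the two forms, assuming C# 7 or later; the class name ToListDemo and the string-length weight are hypothetical. Summing the deferred query twice weighs every item twice, while summing the materialized list twice weighs every item once.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ToListDemo
{
    // Returns the number of weight calls for two passes over each form.
    public static (int Deferred, int Materialized) CountCalls(IEnumerable<string> sequence)
    {
        int deferredCalls = 0;
        Func<string, double> deferredWeight = s => { deferredCalls++; return s.Length; };
        var lazy = sequence.Select(item => new KeyValuePair<string, double>(item, deferredWeight(item)));
        lazy.Sum(p => p.Value); // each pass re-runs the projection
        lazy.Sum(p => p.Value);

        int listCalls = 0;
        Func<string, double> listWeight = s => { listCalls++; return s.Length; };
        var list = sequence.Select(item => new KeyValuePair<string, double>(item, listWeight(item))).ToList();
        list.Sum(p => p.Value); // the list already holds the computed pairs
        list.Sum(p => p.Value);

        return (deferredCalls, listCalls);
    }
}
```

For three items this gives (6, 3). On its own, ToList() computes every weight before any check runs, so it does not stop early at a negative weight; that part still needs a check inside the projection.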

answered 2012-04-20T23:07:52.207

You could do this with a foreach loop. Here is a way to do it in a single statement:

IEnumerable<KeyValuePair<T, double>> items = sequence
    .Select(item => new KeyValuePair<T, double>(item, weight(item)))
    .Select(kvp =>
    {
        if (kvp.Value < 0)
            throw new ArgumentException("Item weights cannot be less than zero.");
        else
            return kvp;
    }
    );
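The single-statement version above is still a deferred query, so every enumeration of items calls weight again. For comparison, here is a minimal sketch of the plain foreach loop, assuming C# 7 or later; the helper name WeightLoop.Evaluate is hypothetical. It weighs each item exactly once, stops at the first negative weight, and returns the materialized pairs together with their sum.

```csharp
using System;
using System.Collections.Generic;

public static class WeightLoop
{
    // Hypothetical helper: one pass, one weight call per item, early exit on a negative weight.
    public static (List<KeyValuePair<T, double>> Items, double Sum) Evaluate<T>(
        IEnumerable<T> sequence, Func<T, double> weight)
    {
        var items = new List<KeyValuePair<T, double>>();
        double sum = 0;
        foreach (T item in sequence)
        {
            double w = weight(item);
            if (w < 0)
                throw new ArgumentException("Item weights cannot be less than zero.");
            items.Add(new KeyValuePair<T, double>(item, w));
            sum += w; // accumulate while we go, so no second pass is needed
        }
        return (items, sum);
    }
}
```

The returned list can then be iterated as often as needed without calling weight again.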

answered 2012-04-20T23:13:24.840

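No, there is nothing built into LINQ that does this, but you can certainly write your own methods and call them from a LINQ query (as several answers have already shown).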

Personally, I would either call ToList on the first query or use Eric's suggestion.
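A minimal sketch combining both ideas; the helper name CheckedToList.WeighAll is hypothetical. The check sits inside the projection and the query is then materialized with ToList(), so each weight is computed at most once and the first negative (or NaN) weight stops the pass before later items are weighed.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class CheckedToList
{
    // Hypothetical helper: validates each weight as it is computed and
    // materializes the pairs in a single pass.
    public static List<KeyValuePair<T, double>> WeighAll<T>(
        IEnumerable<T> sequence, Func<T, double> weight) =>
        sequence
            .Select(item =>
            {
                double w = weight(item);
                // !(w >= 0.0) also rejects NaN, as in the first answer.
                if (!(w >= 0.0))
                    throw new ArgumentException("Item weights cannot be less than zero.");
                return new KeyValuePair<T, double>(item, w);
            })
            .ToList(); // weights are stored; later passes do not call weight again
}
```

After this, computing the sum and running the foreach both work on the list, with no further weight calls.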

answered 2012-04-20T23:20:12.733

In addition to the function memoization suggested by the other answers, you can also memoize the whole data sequence:

var items = sequence
    .Select(item => new KeyValuePair<T, double>(item, weight(item)))
    .Memoize();

(Note the Memoize() call at the end of the expression above.)

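A nice property of data memoization is that it can replace the ToList() or ToArray() methods: each element is still evaluated only once, but lazily, when it is first requested.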

A full-featured implementation is fairly involved:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;

static class MemoizationExtensions
{
    /// <summary>
    /// Memoize all elements of a sequence, i.e. ensure that every element of a sequence is retrieved only once.
    /// </summary>
    /// <remarks>
    /// The resulting sequence is not thread safe.
    /// </remarks>
    /// <typeparam name="T">The type of the elements of source.</typeparam>
    /// <param name="source">The source sequence.</param>
    /// <returns>The sequence that fully replicates the source with all elements being memoized.</returns>
    public static IEnumerable<T> Memoize<T>(this IEnumerable<T> source) => Memoize(source, false);

    /// <summary>
    /// Memoize all elements of a sequence, i.e. ensure that every element of a sequence is retrieved only once.
    /// </summary>
    /// <typeparam name="T">The type of the elements of source.</typeparam>
    /// <param name="source">The source sequence.</param>
    /// <param name="isThreadSafe">Indicates whether the resulting sequence is thread safe.</param>
    /// <returns>The sequence that fully replicates the source with all elements being memoized.</returns>
    public static IEnumerable<T> Memoize<T>(this IEnumerable<T> source, bool isThreadSafe)
    {
        switch (source)
        {
            case null:
                return null;

            case CachedEnumerable<T> existingCachedEnumerable:
                if (!isThreadSafe || existingCachedEnumerable is ThreadSafeCachedEnumerable<T>)
                {
                    // The source is already memoized with compatible parameters.
                    return existingCachedEnumerable;
                }
                break;

            case IList<T> _:
            case IReadOnlyList<T> _:
            case string _:
                // Given source types are intrinsically memoized by their nature.
                return source;
        }

        if (isThreadSafe)
            return new ThreadSafeCachedEnumerable<T>(source);
        else
            return new CachedEnumerable<T>(source);
    }

    class CachedEnumerable<T> : IEnumerable<T>, IReadOnlyList<T>
    {
        public CachedEnumerable(IEnumerable<T> source)
        {
            _Source = source;
        }

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        IEnumerable<T> _Source;

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        IEnumerator<T> _SourceEnumerator;

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        protected readonly IList<T> Cache = new List<T>();

        public virtual int Count
        {
            get
            {
                while (_TryCacheElementNoLock()) ;
                return Cache.Count;
            }
        }

        bool _TryCacheElementNoLock()
        {
            if (_SourceEnumerator == null && _Source != null)
            {
                _SourceEnumerator = _Source.GetEnumerator();
                _Source = null;
            }

            if (_SourceEnumerator == null)
            {
                // Source enumerator already reached the end.
                return false;
            }
            else if (_SourceEnumerator.MoveNext())
            {
                Cache.Add(_SourceEnumerator.Current);
                return true;
            }
            else
            {
                // Source enumerator has reached the end, so it is no longer needed.
                _SourceEnumerator.Dispose();
                _SourceEnumerator = null;
                return false;
            }
        }

        public virtual T this[int index]
        {
            get
            {
                _EnsureItemIsCachedNoLock(index);
                return Cache[index];
            }
        }

        public IEnumerator<T> GetEnumerator() => new CachedEnumerator<T>(this);

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

        internal virtual bool EnsureItemIsCached(int index) => _EnsureItemIsCachedNoLock(index);

        bool _EnsureItemIsCachedNoLock(int index)
        {
            while (Cache.Count <= index)
            {
                if (!_TryCacheElementNoLock())
                    return false;
            }
            return true;
        }

        internal virtual T GetCacheItem(int index) => Cache[index];
    }

    sealed class ThreadSafeCachedEnumerable<T> : CachedEnumerable<T>
    {
        public ThreadSafeCachedEnumerable(IEnumerable<T> source) :
            base(source)
        {
        }

        public override int Count
        {
            get
            {
                lock (Cache)
                    return base.Count;
            }
        }

        public override T this[int index]
        {
            get
            {
                lock (Cache)
                    return base[index];
            }
        }

        internal override bool EnsureItemIsCached(int index)
        {
            lock (Cache)
                return base.EnsureItemIsCached(index);
        }

        internal override T GetCacheItem(int index)
        {
            lock (Cache)
                return base.GetCacheItem(index);
        }
    }

    sealed class CachedEnumerator<T> : IEnumerator<T>
    {
        CachedEnumerable<T> _CachedEnumerable;

        const int InitialIndex = -1;
        const int EofIndex = -2;

        int _Index = InitialIndex;

        public CachedEnumerator(CachedEnumerable<T> cachedEnumerable)
        {
            _CachedEnumerable = cachedEnumerable;
        }

        public T Current
        {
            get
            {
                var cachedEnumerable = _CachedEnumerable;
                if (cachedEnumerable == null)
                    throw new InvalidOperationException();

                var index = _Index;
                if (index < 0)
                    throw new InvalidOperationException();

                return cachedEnumerable.GetCacheItem(index);
            }
        }

        object IEnumerator.Current => Current;

        public void Dispose()
        {
            _CachedEnumerable = null;
        }

        public bool MoveNext()
        {
            var cachedEnumerable = _CachedEnumerable;
            if (cachedEnumerable == null)
            {
                // Disposed.
                return false;
            }

            if (_Index == EofIndex)
                return false;

            _Index++;
            if (!cachedEnumerable.EnsureItemIsCached(_Index))
            {
                _Index = EofIndex;
                return false;
            }
            else
            {
                return true;
            }
        }

        public void Reset()
        {
            _Index = InitialIndex;
        }
    }
}

More information and a ready-made NuGet package: https://github.com/gapotchenko/Gapotchenko.FX/tree/master/Source/Gapotchenko.FX.Linq#memoize
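For comparison with the full-featured version above, here is a deliberately minimal sketch of the same lazy-caching idea, assuming C# 7 or later. It is not thread-safe, does not special-case IList<T> sources, and does no special handling if the source throws; the names SimpleCachedEnumerable, MemoizeSimple and CountCalls are hypothetical.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Hypothetical minimal version: caches elements as they are first pulled.
public sealed class SimpleCachedEnumerable<T> : IEnumerable<T>
{
    private readonly List<T> _cache = new List<T>();
    private IEnumerator<T> _source; // set to null once the source is exhausted

    // Getting the enumerator does not run a Select projection yet.
    public SimpleCachedEnumerable(IEnumerable<T> source) => _source = source.GetEnumerator();

    public IEnumerator<T> GetEnumerator()
    {
        for (int i = 0; ; i++)
        {
            if (i < _cache.Count)
            {
                yield return _cache[i]; // already computed: serve from the cache
            }
            else if (_source != null && _source.MoveNext())
            {
                _cache.Add(_source.Current); // computed for the first time
                yield return _cache[i];
            }
            else
            {
                _source?.Dispose(); // source finished: release it
                _source = null;
                yield break;
            }
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public static class SimpleMemoizeExtensions
{
    public static IEnumerable<T> MemoizeSimple<T>(this IEnumerable<T> source) =>
        new SimpleCachedEnumerable<T>(source ?? throw new ArgumentNullException(nameof(source)));

    // Hypothetical demo: counts projection calls after a partial pass and two full passes.
    public static (int AfterFirst, int AfterTwoSums) CountCalls(string[] words)
    {
        int calls = 0;
        var lengths = words.Select(w => { calls++; return w.Length; }).MemoizeSimple();
        lengths.First(); // pulls only the first element
        int afterFirst = calls;
        lengths.Sum(); // pulls and caches the rest
        lengths.Sum(); // served entirely from the cache
        return (afterFirst, calls);
    }
}
```

For three words, CountCalls returns (1, 3), which shows both the laziness and the once-per-element evaluation.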

answered 2019-04-03T21:36:42.337