1

我得到了一个像增强的Task.WhenAll. 它需要一堆任务并在全部完成后返回。

public async Task MyWhenAll(Task[] tasks) {
    ...
    await Something();
    ...

    // all tasks are completed
    if (someTasksFailed)
        throw ??
}

我的问题是如何让方法返回一个看起来像Task.WhenAll一个或多个任务失败时返回的任务?

如果我收集异常并抛出异常,AggregateException它将被包装在另一个 AggregateException 中。

编辑:完整示例

async Task Main() {
    try {
        Task.WhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }

    try {
        MyWhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }
}

public async Task MyWhenAll(Task t1, Task t2) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    try {
        await Task.WhenAll(t1, t2);
    }
    catch {
        throw new AggregateException(new[] { t1.Exception, t2.Exception });
    }
}
public async Task Throw(int id) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    throw new InvalidOperationException("Inner" + id);
}

因为Task.WhenAll异常AggregateException有 2 个内部异常。

因为MyWhenAll例外是AggregateException一个内部AggregateException有 2 个内部例外。

编辑:我为什么要这样做

我经常需要调用分页 API:s 并希望限制同时连接的数量。

实际的方法签名是

public static async Task<TResult[]> AsParallelAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel)
public static async Task<TResult[]> AsParallelUntilAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel, Func<Task<TResult>, bool> predicate)

这意味着我可以像这样进行分页

var pagedRecords = await Enumerable.Range(1, int.MaxValue)
                                   .Select(x => GetRecordsAsync(pageSize: 1000, pageNumber: x)
                                   .AsParallelUntilAsync(maxParallel: 5, x => x.Result.Count < 1000);
var records = pagedRecords.SelectMany(x => x).ToList();

一切正常,聚合中的聚合只是一个小不便。

4

3 回答 3

3

async方法被设计为每个返回的任务最多只设置一个异常,而不是多个。

这给您留下了两个选择,您可以不使用async方法开始,而是依靠其他方式来执行您的方法:

public Task MyWhenAll(Task t1, Task t2)
{
    return Task.Delay(TimeSpan.FromMilliseconds(100))
        .ContinueWith(_ => Task.WhenAll(t1, t2))
        .Unwrap();
}

如果您有一个更复杂的方法,如果不使用会更难编写await,那么您需要解开嵌套的聚合异常,虽然不是过于复杂,但这样做很乏味:

    public static Task UnwrapAggregateException(this Task taskToUnwrap)
    {
        var tcs = new TaskCompletionSource<bool>();

        taskToUnwrap.ContinueWith(task =>
        {
            if (task.IsCanceled)
                tcs.SetCanceled();
            else if (task.IsFaulted)
            {
                if (task.Exception is AggregateException aggregateException)
                    tcs.SetException(Flatten(aggregateException));
                else
                    tcs.SetException(task.Exception);
            }
            else //successful
                tcs.SetResult(true);
        });

        IEnumerable<Exception> Flatten(AggregateException exception)
        {
            var stack = new Stack<AggregateException>();
            stack.Push(exception);
            while (stack.Any())
            {
                var next = stack.Pop();
                foreach (Exception inner in next.InnerExceptions)
                {
                    if (inner is AggregateException innerAggregate)
                        stack.Push(innerAggregate);
                    else
                        yield return inner;
                }
            }
        }

        return tcs.Task;
    }
于 2019-04-15T15:38:49.447 回答
0

使用TaskCompletionSource.

最外层的异常由.Wait()or创建.Result- 这被记录为将存储在 Task 中的异常包装在 an 内AggregateException(以保留其堆栈跟踪 - 这是在ExceptionDispatchInfo创建之前引入的)。

但是,Task 实际上可以包含许多异常。在这种情况下,.Wait().Result抛出一个AggregateException包含多个InnerExceptions. 您可以通过 访问此功能TaskCompletionSource.SetException(IEnumerable<Exception> exceptions)

所以你不想创建你自己的AggregateException. 在任务上设置多个例外,并为您创建.Wait().Result创建AggregateException

所以:

var tcs = new TaskCompletionSource<object>();
tcs.SetException(new[] { t1.Exception, t2.Exception });
return tcs.Task;

当然,如果你再调用await MyWhenAll(..)or MyWhenAll(..).GetAwaiter().GetResult(),那么它只会抛出第一个异常。这符合 的行为Task.WhenAll

这意味着您需要将tcs.Taskup 作为方法的返回值传递,这意味着您的方法不能是async. 你最终会做这样丑陋的事情(根据你的问题调整示例代码):

public static Task MyWhenAll(Task t1, Task t2)
{
    var tcs = new TaskCompletionSource<object>();
    var _ = Impl();
    return tcs.Task;

    async Task Impl()
    {
        await Task.Delay(10);
        try
        {
            await Task.WhenAll(t1, t2);
            tcs.SetResult(null);
        }
        catch
        {
            tcs.SetException(new[] { t1.Exception, t2.Exception });
        }
    }
}

不过,此时我会开始询问您为什么要这样做,以及为什么不能直接使用Task返回的 from Task.WhenAll

于 2019-04-15T15:18:03.420 回答
0

我删除了我之前的答案,因为我找到了一个更简单的解决方案。该解决方案不涉及讨厌的ContinueWith方法或TaskCompletionSource类型。Task<Task>这个想法是从一个本地函数返回一个嵌套函数,并Unwrap()从外部容器函数返回一个嵌套函数。这是这个想法的基本概述:

public Task<T[]> GetAllAsync<T>()
{
    return LocalAsyncFunction().Unwrap();

    async Task<Task<T[]>> LocalAsyncFunction()
    {
        var tasks = new List<Task<T>>();
        // ...
        await SomethingAsync();
        // ...
        Task<T[]> whenAll = Task.WhenAll(tasks);
        return whenAll;
    }
}

方法GetAllAsync不是。async它将所有工作委托给LocalAsyncFunction, 即async, 然后Unwraps 生成的嵌套任务并返回它。展开的任务在其.Exception.InnerExceptions属性中包含 的所有异常tasks,因为它只是内部Task.WhenAll任务的外观。

让我们演示一下这个想法的更实际的实现。下面的AsParallelUntilAsync方法懒惰地枚举source序列并将其包含的项目投影到Task<TResult>s,直到一个项目满足predicate. 它还限制了异步操作的并发性。困难在于枚举IEnumerable<TSource>也可能引发异常。在这种情况下,正确的行为是在传播枚举错误之前等待所有正在运行的任务,并返回AggregateException包含枚举错误以及同时可能发生的所有任务错误的 an。这是如何完成的:

public static Task<TResult[]> AsParallelUntilAsync<TSource, TResult>(
    this IEnumerable<TSource> source, Func<TSource, Task<TResult>> action,
    Func<TSource, bool> predicate, int maxConcurrency)
{
    return Implementation().Unwrap();

    async Task<Task<TResult[]>> Implementation()
    {
        var tasks = new List<Task<TResult>>();

        async Task<TResult> EnumerateAsync()
        {
            var semaphore = new SemaphoreSlim(maxConcurrency, maxConcurrency);
            using var enumerator = source.GetEnumerator();
            while (true)
            {
                await semaphore.WaitAsync();
                if (!enumerator.MoveNext()) break;
                var item = enumerator.Current;
                if (predicate(item)) break;

                async Task<TResult> RunAndRelease(TSource item)
                {
                    try { return await action(item); }
                    finally { semaphore.Release(); }
                }

                tasks.Add(RunAndRelease(item));
            }
            return default; // A dummy value that will never be returned
        }

        Task<TResult> enumerateTask = EnumerateAsync();

        try
        {
            await enumerateTask; // Make sure that the enumeration succeeded
            Task<TResult[]> whenAll = Task.WhenAll(tasks);
            await whenAll; // Make sure that all the tasks succeeded
            return whenAll;
        }
        catch
        {
            // Return a faulted task that contains ALL the errors!
            return Task.WhenAll(tasks.Prepend(enumerateTask));
        }
    }
}
于 2021-09-21T23:07:11.430 回答