diff --git a/MoreLinq/Experimental/Await.cs b/MoreLinq/Experimental/Await.cs index c431f2ed0..229fccbcb 100644 --- a/MoreLinq/Experimental/Await.cs +++ b/MoreLinq/Experimental/Await.cs @@ -416,45 +416,124 @@ public static IAwaitQuery AwaitCompletion( return AwaitQuery.Create( - options => _(options.MaxConcurrency ?? int.MaxValue, + options => _(options.MaxConcurrency, options.Scheduler ?? TaskScheduler.Default, options.PreserveOrder)); - IEnumerable _(int maxConcurrency, TaskScheduler scheduler, bool ordered) + IEnumerable _(int? maxConcurrency, TaskScheduler scheduler, bool ordered) { + // A separate task will enumerate the source and launch tasks. + // It will post all progress as notices to the collection below. + // A notice is essentially a discriminated union like: + // + // type Notice<'a, 'b> = + // | End + // | Result of (int * 'a * Task<'b>) + // | Error of ExceptionDispatchInfo + // + // Note that BlockingCollection.CompleteAdding is never used to + // to mark the end (which its own notice above) because + // BlockingCollection.Add throws if called after CompleteAdding + // and we want to deliberately tolerate the race condition. + var notices = new BlockingCollection<(Notice, (int, T, Task), ExceptionDispatchInfo)>(); - var cancellationTokenSource = new CancellationTokenSource(); - var cancellationToken = cancellationTokenSource.Token; - var completed = false; - var enumerator = - source.Index() - .Select(e => (e.Key, Item: e.Value, Task: evaluator(e.Value, cancellationToken))) - .GetEnumerator(); + var consumerCancellationTokenSource = new CancellationTokenSource(); + (Exception, Exception) lastCriticalErrors = default; + + void PostNotice(Notice notice, + (int, T, Task) item, + Exception error) + { + // If a notice fails to post then assume critical error + // conditions (like low memory), capture the error without + // further allocation of resources and trip the cancellation + // token source used by the main loop waiting on notices. + // Note that only the "last" critical error is reported + // as maintaining a list would incur allocations. The idea + // here is to make a best effort attempt to report any of + // the error conditions that may be occuring, which is still + // better than nothing. + + try + { + var edi = error != null + ? ExceptionDispatchInfo.Capture(error) + : null; + notices.Add((notice, item, edi)); + } + catch (Exception e) + { + // Don't use ExceptionDispatchInfo.Capture here to avoid + // inducing allocations if already under low memory + // conditions. + + lastCriticalErrors = (e, error); + consumerCancellationTokenSource.Cancel(); + throw; + } + } + + var completed = false; + var cancellationTokenSource = new CancellationTokenSource(); + var enumerator = source.Index().GetEnumerator(); IDisposable disposable = enumerator; // disables AccessToDisposedClosure warnings try { + var cancellationToken = cancellationTokenSource.Token; + + // Fire-up a parallel loop to iterate through the source and + // launch tasks, posting a result-notice as each task + // completes and another, an end-notice, when all tasks have + // completed. + Task.Factory.StartNew( - () => - CollectToAsync( - enumerator, - e => e.Task, - notices, - (e, r) => (Notice.Result, (e.Key, e.Item, e.Task), default), - ex => (Notice.Error, default, ExceptionDispatchInfo.Capture(ex)), - (Notice.End, default, default), - maxConcurrency, cancellationTokenSource), + async () => + { + try + { + await enumerator.StartAsync( + e => evaluator(e.Value, cancellationToken), + (e, r) => PostNotice(Notice.Result, (e.Key, e.Value, r), default), + () => PostNotice(Notice.End, default, default), + maxConcurrency, cancellationToken); + } + catch (Exception e) + { + PostNotice(Notice.Error, default, e); + } + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, scheduler); + // Remainder here is the main loop that waits for and + // processes notices. + var nextKey = 0; var holds = ordered ? new List<(int, T, Task)>() : null; - foreach (var (kind, result, error) in notices.GetConsumingEnumerable()) + using (var notice = notices.GetConsumingEnumerable(consumerCancellationTokenSource.Token) + .GetEnumerator()) + while (true) { + try + { + if (!notice.MoveNext()) + break; + } + catch (OperationCanceledException e) when (e.CancellationToken == consumerCancellationTokenSource.Token) + { + var (error1, error2) = lastCriticalErrors; + throw new Exception("One or more critical errors have occurred.", + error2 != null ? new AggregateException(error1, error2) + : new AggregateException(error1)); + } + + var (kind, result, error) = notice.Current; + if (kind == Notice.Error) error.Throw(); @@ -531,149 +610,76 @@ IEnumerable _(int maxConcurrency, TaskScheduler scheduler, bool ordered } } - enum Notice { Result, Error, End } - - static async Task CollectToAsync( - this IEnumerator e, - Func> taskSelector, - BlockingCollection collection, - Func, TNotice> completionNoticeSelector, - Func errorNoticeSelector, - TNotice endNotice, - int maxConcurrency, - CancellationTokenSource cancellationTokenSource) + enum Notice { End, Result, Error } + + static async Task StartAsync( + this IEnumerator enumerator, + Func> starter, + Action> onTaskCompletion, + Action onEnd, + int? maxConcurrency, + CancellationToken cancellationToken) { - Reader reader = null; + if (enumerator == null) throw new ArgumentNullException(nameof(enumerator)); + if (starter == null) throw new ArgumentNullException(nameof(starter)); + if (onTaskCompletion == null) throw new ArgumentNullException(nameof(onTaskCompletion)); + if (onEnd == null) throw new ArgumentNullException(nameof(onEnd)); + if (maxConcurrency < 1) throw new ArgumentOutOfRangeException(nameof(maxConcurrency)); - try + using (enumerator) { - reader = new Reader(e); - - var cancellationToken = cancellationTokenSource.Token; - var cancellationTaskSource = new TaskCompletionSource(); - cancellationToken.Register(() => cancellationTaskSource.TrySetResult(true)); + var pendingCount = 1; // terminator - var tasks = new List<(T Item, Task Task)>(); - - for (var i = 0; i < maxConcurrency; i++) + void OnPendingCompleted() { - if (!reader.TryRead(out var item)) - break; - tasks.Add((item, taskSelector(item))); + if (Interlocked.Decrement(ref pendingCount) == 0) + onEnd(); } - while (tasks.Count > 0) + var concurrencyGate = maxConcurrency is int count + ? new ConcurrencyGate(count) + : ConcurrencyGate.Unbounded; + + while (enumerator.MoveNext()) { - // Task.WaitAny is synchronous and blocking but allows the - // waiting to be cancelled via a CancellationToken. - // Task.WhenAny can be awaited so it is better since the - // thread won't be blocked and can return to the pool. - // However, it doesn't support cancellation so instead a - // task is built on top of the CancellationToken that - // completes when the CancellationToken trips. - // - // Also, Task.WhenAny returns the task (Task) object that - // completed but task objects may not be unique due to - // caching, e.g.: - // - // async Task Foo() => true; - // async Task Bar() => true; - // var foo = Foo(); - // var bar = Bar(); - // var same = foo.Equals(bar); // == true - // - // In this case, the task returned by Task.WhenAny will - // match `foo` and `bar`: - // - // var done = Task.WhenAny(foo, bar); - // - // Logically speaking, the uniqueness of a task does not - // matter but here it does, especially when Await (the main - // user of CollectAsync) needs to return results ordered. - // Fortunately, we compose our own task on top of the - // original that links each item with the task result and as - // a consequence generate new and unique task objects. - - var completedTask = await - Task.WhenAny(tasks.Select(it => (Task) it.Task).Concat(cancellationTaskSource.Task)) - .ConfigureAwait(continueOnCapturedContext: false); - - if (completedTask == cancellationTaskSource.Task) + try { - // Cancellation during the wait means the enumeration - // has been stopped by the user so the results of the - // remaining tasks are no longer needed. Those tasks - // should cancel as a result of sharing the same - // cancellation token and provided that they passed it - // on to any downstream asynchronous operations. Either - // way, this loop is done so exit hard here. - - return; + await concurrencyGate.EnterAsync(cancellationToken); } - - var i = tasks.FindIndex(it => it.Task.Equals(completedTask)); - + catch (OperationCanceledException e) when (e.CancellationToken == cancellationToken) { - var (item, task) = tasks[i]; - tasks.RemoveAt(i); + return; + } - // Await the task rather than using its result directly - // to avoid having the task's exception bubble up as - // AggregateException if the task failed. + Interlocked.Increment(ref pendingCount); - collection.Add(completionNoticeSelector(item, task)); - } + var item = enumerator.Current; + var task = starter(item); - { - if (reader.TryRead(out var item)) - tasks.Add((item, taskSelector(item))); - } - } + // Add a continutation that notifies completion of the task, + // along with the necessary housekeeping, in case it + // completes before maximum concurrency is reached. - collection.Add(endNotice); - } - catch (Exception ex) - { - cancellationTokenSource.Cancel(); - collection.Add(errorNoticeSelector(ex)); - } - finally - { - reader?.Dispose(); - } + #pragma warning disable 4014 // https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/compiler-messages/cs4014 - collection.CompleteAdding(); - } + task.ContinueWith(cancellationToken: cancellationToken, + continuationOptions: TaskContinuationOptions.ExecuteSynchronously, + scheduler: TaskScheduler.Current, + continuationAction: t => + { + concurrencyGate.Exit(); - sealed class Reader : IDisposable - { - IEnumerator _enumerator; + if (cancellationToken.IsCancellationRequested) + return; - public Reader(IEnumerator enumerator) => - _enumerator = enumerator; + onTaskCompletion(item, t); + OnPendingCompleted(); + }); - public bool TryRead(out T item) - { - var ended = false; - if (_enumerator == null || (ended = !_enumerator.MoveNext())) - { - if (ended) - Dispose(); - item = default; - return false; + #pragma warning restore 4014 } - item = _enumerator.Current; - return true; - } - - public void Dispose() - { - var e = _enumerator; - if (e == null) - return; - _enumerator = null; - e.Dispose(); + OnPendingCompleted(); } } @@ -720,6 +726,53 @@ static class TupleComparer public static readonly IComparer<(T1, T2, T3)> Item3 = Comparer<(T1, T2, T3)>.Create((x, y) => Comparer.Default.Compare(x.Item3, y.Item3)); } + + static class CompletedTask + { + #if NET451 || NETSTANDARD1_0 + + public static readonly Task Instance; + + static CompletedTask() + { + var tcs = new TaskCompletionSource(); + tcs.SetResult(null); + Instance = tcs.Task; + } + + #else + + public static readonly Task Instance = Task.CompletedTask; + + #endif + } + + sealed class ConcurrencyGate + { + public static readonly ConcurrencyGate Unbounded = new ConcurrencyGate(); + + readonly SemaphoreSlim _semaphore; + + ConcurrencyGate(SemaphoreSlim semaphore = null) => + _semaphore = semaphore; + + public ConcurrencyGate(int max) : + this(new SemaphoreSlim(max, max)) {} + + public Task EnterAsync(CancellationToken token) + { + if (_semaphore == null) + { + token.ThrowIfCancellationRequested(); + return CompletedTask.Instance; + } + + return _semaphore.WaitAsync(token); + } + + public void Exit() => + _semaphore?.Release(); + } } }