From 3252ef97f93f188c61c3c0573b23a327db771ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20K=C3=B6llner?= Date: Wed, 15 Apr 2020 08:13:42 +0200 Subject: [PATCH] fixed deadlock in StrictlyOrderedAwaitTask --- Capnp.Net.Runtime.Tests/LocalRpc.cs | 35 +++++++++- .../Util/StrictlyOrderedAwaitTask.cs | 70 ++++++++++--------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/Capnp.Net.Runtime.Tests/LocalRpc.cs b/Capnp.Net.Runtime.Tests/LocalRpc.cs index 55e693a..7517cd2 100644 --- a/Capnp.Net.Runtime.Tests/LocalRpc.cs +++ b/Capnp.Net.Runtime.Tests/LocalRpc.cs @@ -187,11 +187,44 @@ namespace Capnp.Net.Runtime.Tests var genTask = Task.Run(() => Generator()); var verTask = Verifier(); SpinWait.SpinUntil(() => Volatile.Read(ref counter) >= 100); - tcs.SetResult(impl); + Task.Run(() => tcs.SetResult(impl)); cts.Cancel(); Assert.IsTrue(genTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(verTask.Wait(MediumNonDbgTimeout)); } } + + [TestMethod] + public void EagerRace2() + { + for (int i = 0; i < 100; i++) + { + var tcs1 = new TaskCompletionSource(); + var tcs2 = new TaskCompletionSource(); + + var t1 = Capnp.Util.StrictlyOrderedTaskExtensions.EnforceAwaitOrder(tcs1.Task); + var t2 = Capnp.Util.StrictlyOrderedTaskExtensions.EnforceAwaitOrder(tcs2.Task); + + async Task Wait1() + { + await t1; + await t2; + } + + async Task Wait2() + { + await t2; + await t1; + } + + var w1 = Wait1(); + var w2 = Wait2(); + + Task.Run(() => tcs1.SetResult(0)); + Task.Run(() => tcs2.SetResult(0)); + + Assert.IsTrue(Task.WaitAll(new Task[] { w1, w2 }, MediumNonDbgTimeout)); + } + } } } diff --git a/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs b/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs index f5a26eb..5cf1335 100644 --- a/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs +++ b/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs @@ -7,16 +7,17 @@ using System.Threading.Tasks; namespace Capnp.Util { - internal class StrictlyOrderedAwaitTask: INotifyCompletion + public class StrictlyOrderedAwaitTask: INotifyCompletion { + static readonly Action Capsule = () => throw new InvalidProgramException("Not invocable"); + readonly Task _awaitedTask; - object? _lock; - long _inOrder, _outOrder; + Action? _state; public StrictlyOrderedAwaitTask(Task awaitedTask) { _awaitedTask = awaitedTask; - _lock = new object(); + AwaitInternal(); } public StrictlyOrderedAwaitTask GetAwaiter() @@ -24,25 +25,10 @@ namespace Capnp.Util return this; } - public async void OnCompleted(Action continuation) + async void AwaitInternal() { - object? safeLock = Volatile.Read(ref _lock); - - if (safeLock == null) - { - continuation(); - return; - } - - long sequence = Interlocked.Increment(ref _inOrder) - 1; - try { - if (_awaitedTask.IsCompleted) - { - Interlocked.Exchange(ref _lock, null); - } - await _awaitedTask; } catch @@ -52,24 +38,44 @@ namespace Capnp.Util { SpinWait.SpinUntil(() => { - lock (safeLock) + Action? continuations; + do { - if (Volatile.Read(ref _outOrder) != sequence) - { - return false; - } + continuations = Interlocked.Exchange(ref _state, null); + continuations?.Invoke(); - Interlocked.Increment(ref _outOrder); + } while (continuations != null); - continuation(); - - return true; - } + return Interlocked.CompareExchange(ref _state, Capsule, null) == null; }); } } - public bool IsCompleted => Volatile.Read(ref _lock) == null || (_awaitedTask.IsCompleted && Volatile.Read(ref _inOrder) == 0); + public void OnCompleted(Action continuation) + { + SpinWait.SpinUntil(() => { + Action? cur, next; + cur = Volatile.Read(ref _state); + switch (cur) + { + case null: + next = continuation; + break; + + case Action capsule when capsule == Capsule: + continuation(); + return true; + + case Action action: + next = action + continuation; + break; + } + + return Interlocked.CompareExchange(ref _state, next, cur) == cur; + }); + } + + public bool IsCompleted => _awaitedTask.IsCompleted && _state == Capsule; public T GetResult() => _awaitedTask.GetAwaiter().GetResult(); @@ -78,7 +84,7 @@ namespace Capnp.Util public Task WrappedTask => _awaitedTask; } - internal static class StrictlyOrderedTaskExtensions + public static class StrictlyOrderedTaskExtensions { public static StrictlyOrderedAwaitTask EnforceAwaitOrder(this Task task) {