fixed deadlock in StrictlyOrderedAwaitTask

This commit is contained in:
Christian Köllner 2020-04-15 08:13:42 +02:00
parent 4441ef6c8e
commit 3252ef97f9
2 changed files with 72 additions and 33 deletions

View File

@ -187,11 +187,44 @@ namespace Capnp.Net.Runtime.Tests
var genTask = Task.Run(() => Generator()); var genTask = Task.Run(() => Generator());
var verTask = Verifier(); var verTask = Verifier();
SpinWait.SpinUntil(() => Volatile.Read(ref counter) >= 100); SpinWait.SpinUntil(() => Volatile.Read(ref counter) >= 100);
tcs.SetResult(impl); Task.Run(() => tcs.SetResult(impl));
cts.Cancel(); cts.Cancel();
Assert.IsTrue(genTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(genTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(verTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(verTask.Wait(MediumNonDbgTimeout));
} }
} }
[TestMethod]
public void EagerRace2()
{
for (int i = 0; i < 100; i++)
{
var tcs1 = new TaskCompletionSource<int>();
var tcs2 = new TaskCompletionSource<int>();
var t1 = Capnp.Util.StrictlyOrderedTaskExtensions.EnforceAwaitOrder(tcs1.Task);
var t2 = Capnp.Util.StrictlyOrderedTaskExtensions.EnforceAwaitOrder(tcs2.Task);
async Task Wait1()
{
await t1;
await t2;
}
async Task Wait2()
{
await t2;
await t1;
}
var w1 = Wait1();
var w2 = Wait2();
Task.Run(() => tcs1.SetResult(0));
Task.Run(() => tcs2.SetResult(0));
Assert.IsTrue(Task.WaitAll(new Task[] { w1, w2 }, MediumNonDbgTimeout));
}
}
} }
} }

View File

@ -7,16 +7,17 @@ using System.Threading.Tasks;
namespace Capnp.Util namespace Capnp.Util
{ {
internal class StrictlyOrderedAwaitTask<T>: INotifyCompletion public class StrictlyOrderedAwaitTask<T>: INotifyCompletion
{ {
static readonly Action Capsule = () => throw new InvalidProgramException("Not invocable");
readonly Task<T> _awaitedTask; readonly Task<T> _awaitedTask;
object? _lock; Action? _state;
long _inOrder, _outOrder;
public StrictlyOrderedAwaitTask(Task<T> awaitedTask) public StrictlyOrderedAwaitTask(Task<T> awaitedTask)
{ {
_awaitedTask = awaitedTask; _awaitedTask = awaitedTask;
_lock = new object(); AwaitInternal();
} }
public StrictlyOrderedAwaitTask<T> GetAwaiter() public StrictlyOrderedAwaitTask<T> GetAwaiter()
@ -24,25 +25,10 @@ namespace Capnp.Util
return this; return this;
} }
public async void OnCompleted(Action continuation) async void AwaitInternal()
{ {
object? safeLock = Volatile.Read(ref _lock);
if (safeLock == null)
{
continuation();
return;
}
long sequence = Interlocked.Increment(ref _inOrder) - 1;
try try
{ {
if (_awaitedTask.IsCompleted)
{
Interlocked.Exchange(ref _lock, null);
}
await _awaitedTask; await _awaitedTask;
} }
catch catch
@ -52,24 +38,44 @@ namespace Capnp.Util
{ {
SpinWait.SpinUntil(() => SpinWait.SpinUntil(() =>
{ {
lock (safeLock) Action? continuations;
do
{ {
if (Volatile.Read(ref _outOrder) != sequence) continuations = Interlocked.Exchange(ref _state, null);
{ continuations?.Invoke();
return false;
}
Interlocked.Increment(ref _outOrder); } while (continuations != null);
continuation(); return Interlocked.CompareExchange(ref _state, Capsule, null) == null;
return true;
}
}); });
} }
} }
public bool IsCompleted => Volatile.Read(ref _lock) == null || (_awaitedTask.IsCompleted && Volatile.Read(ref _inOrder) == 0); public void OnCompleted(Action continuation)
{
SpinWait.SpinUntil(() => {
Action? cur, next;
cur = Volatile.Read(ref _state);
switch (cur)
{
case null:
next = continuation;
break;
case Action capsule when capsule == Capsule:
continuation();
return true;
case Action action:
next = action + continuation;
break;
}
return Interlocked.CompareExchange(ref _state, next, cur) == cur;
});
}
public bool IsCompleted => _awaitedTask.IsCompleted && _state == Capsule;
public T GetResult() => _awaitedTask.GetAwaiter().GetResult(); public T GetResult() => _awaitedTask.GetAwaiter().GetResult();
@ -78,7 +84,7 @@ namespace Capnp.Util
public Task<T> WrappedTask => _awaitedTask; public Task<T> WrappedTask => _awaitedTask;
} }
internal static class StrictlyOrderedTaskExtensions public static class StrictlyOrderedTaskExtensions
{ {
public static StrictlyOrderedAwaitTask<T> EnforceAwaitOrder<T>(this Task<T> task) public static StrictlyOrderedAwaitTask<T> EnforceAwaitOrder<T>(this Task<T> task)
{ {