fixed deadlock in StrictlyOrderedAwaitTask

This commit is contained in:
Christian Köllner 2020-04-15 08:13:42 +02:00
parent 4441ef6c8e
commit 3252ef97f9
2 changed files with 72 additions and 33 deletions

View File

@ -187,11 +187,44 @@ namespace Capnp.Net.Runtime.Tests
var genTask = Task.Run(() => Generator());
var verTask = Verifier();
SpinWait.SpinUntil(() => Volatile.Read(ref counter) >= 100);
tcs.SetResult(impl);
Task.Run(() => tcs.SetResult(impl));
cts.Cancel();
Assert.IsTrue(genTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(verTask.Wait(MediumNonDbgTimeout));
}
}
[TestMethod]
public void EagerRace2()
{
for (int i = 0; i < 100; i++)
{
var tcs1 = new TaskCompletionSource<int>();
var tcs2 = new TaskCompletionSource<int>();
var t1 = Capnp.Util.StrictlyOrderedTaskExtensions.EnforceAwaitOrder(tcs1.Task);
var t2 = Capnp.Util.StrictlyOrderedTaskExtensions.EnforceAwaitOrder(tcs2.Task);
async Task Wait1()
{
await t1;
await t2;
}
async Task Wait2()
{
await t2;
await t1;
}
var w1 = Wait1();
var w2 = Wait2();
Task.Run(() => tcs1.SetResult(0));
Task.Run(() => tcs2.SetResult(0));
Assert.IsTrue(Task.WaitAll(new Task[] { w1, w2 }, MediumNonDbgTimeout));
}
}
}
}

View File

@ -7,16 +7,17 @@ using System.Threading.Tasks;
namespace Capnp.Util
{
internal class StrictlyOrderedAwaitTask<T>: INotifyCompletion
public class StrictlyOrderedAwaitTask<T>: INotifyCompletion
{
static readonly Action Capsule = () => throw new InvalidProgramException("Not invocable");
readonly Task<T> _awaitedTask;
object? _lock;
long _inOrder, _outOrder;
Action? _state;
public StrictlyOrderedAwaitTask(Task<T> awaitedTask)
{
_awaitedTask = awaitedTask;
_lock = new object();
AwaitInternal();
}
public StrictlyOrderedAwaitTask<T> GetAwaiter()
@ -24,25 +25,10 @@ namespace Capnp.Util
return this;
}
public async void OnCompleted(Action continuation)
async void AwaitInternal()
{
object? safeLock = Volatile.Read(ref _lock);
if (safeLock == null)
{
continuation();
return;
}
long sequence = Interlocked.Increment(ref _inOrder) - 1;
try
{
if (_awaitedTask.IsCompleted)
{
Interlocked.Exchange(ref _lock, null);
}
await _awaitedTask;
}
catch
@ -52,24 +38,44 @@ namespace Capnp.Util
{
SpinWait.SpinUntil(() =>
{
lock (safeLock)
Action? continuations;
do
{
if (Volatile.Read(ref _outOrder) != sequence)
{
return false;
}
continuations = Interlocked.Exchange(ref _state, null);
continuations?.Invoke();
Interlocked.Increment(ref _outOrder);
} while (continuations != null);
continuation();
return true;
}
return Interlocked.CompareExchange(ref _state, Capsule, null) == null;
});
}
}
public bool IsCompleted => Volatile.Read(ref _lock) == null || (_awaitedTask.IsCompleted && Volatile.Read(ref _inOrder) == 0);
public void OnCompleted(Action continuation)
{
SpinWait.SpinUntil(() => {
Action? cur, next;
cur = Volatile.Read(ref _state);
switch (cur)
{
case null:
next = continuation;
break;
case Action capsule when capsule == Capsule:
continuation();
return true;
case Action action:
next = action + continuation;
break;
}
return Interlocked.CompareExchange(ref _state, next, cur) == cur;
});
}
public bool IsCompleted => _awaitedTask.IsCompleted && _state == Capsule;
public T GetResult() => _awaitedTask.GetAwaiter().GetResult();
@ -78,7 +84,7 @@ namespace Capnp.Util
public Task<T> WrappedTask => _awaitedTask;
}
internal static class StrictlyOrderedTaskExtensions
public static class StrictlyOrderedTaskExtensions
{
public static StrictlyOrderedAwaitTask<T> EnforceAwaitOrder<T>(this Task<T> task)
{