attempting to fix race cond. with StrictlyOrderedAwaitTask

This commit is contained in:
Christian Köllner 2020-04-20 08:14:02 +02:00
parent e496a19e3d
commit 16c9f8871a
15 changed files with 68 additions and 43 deletions

View File

@ -48,7 +48,12 @@ namespace Capnp.Net.Runtime.Tests
{ {
readonly TaskCompletionSource<DeserializerState> _tcs = new TaskCompletionSource<DeserializerState>(); readonly TaskCompletionSource<DeserializerState> _tcs = new TaskCompletionSource<DeserializerState>();
public Task<DeserializerState> WhenReturned => _tcs.Task; public PromisedAnswerMock()
{
WhenReturned = _tcs.Task.EnforceAwaitOrder();
}
public StrictlyOrderedAwaitTask<DeserializerState> WhenReturned { get; }
public void Return() public void Return()
{ {

View File

@ -1,5 +1,6 @@
using Capnp.Net.Runtime.Tests.GenImpls; using Capnp.Net.Runtime.Tests.GenImpls;
using Capnp.Rpc; using Capnp.Rpc;
using Capnp.Util;
using Capnproto_test.Capnp.Test; using Capnproto_test.Capnp.Test;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using System; using System;
@ -140,7 +141,12 @@ namespace Capnp.Net.Runtime.Tests
class PromisedAnswerMock : IPromisedAnswer class PromisedAnswerMock : IPromisedAnswer
{ {
readonly TaskCompletionSource<DeserializerState> _tcs = new TaskCompletionSource<DeserializerState>(); readonly TaskCompletionSource<DeserializerState> _tcs = new TaskCompletionSource<DeserializerState>();
public Task<DeserializerState> WhenReturned => _tcs.Task; public StrictlyOrderedAwaitTask<DeserializerState> WhenReturned { get; }
public PromisedAnswerMock()
{
WhenReturned = _tcs.Task.EnforceAwaitOrder();
}
public bool IsTailCall => false; public bool IsTailCall => false;
@ -164,7 +170,7 @@ namespace Capnp.Net.Runtime.Tests
{ {
#pragma warning disable CS0618 #pragma warning disable CS0618
var answer = new PromisedAnswerMock(); var answer = new PromisedAnswerMock();
Assert.ThrowsException<ArgumentException>(() => Impatient.GetAnswer(answer.WhenReturned)); Assert.ThrowsException<ArgumentException>(() => Impatient.GetAnswer(Task.FromResult(new object())));
var t = Impatient.MakePipelineAware(answer, _ => _); var t = Impatient.MakePipelineAware(answer, _ => _);
Assert.AreEqual(answer, Impatient.GetAnswer(t)); Assert.AreEqual(answer, Impatient.GetAnswer(t));
#pragma warning restore CS0618 #pragma warning restore CS0618
@ -174,7 +180,8 @@ namespace Capnp.Net.Runtime.Tests
public async Task Access() public async Task Access()
{ {
var answer = new PromisedAnswerMock(); var answer = new PromisedAnswerMock();
var cap = Impatient.Access(answer.WhenReturned, new MemberAccessPath(), Task.FromResult<IDisposable>(new TestInterfaceImpl2())); async Task AwaitReturn() => await answer.WhenReturned;
var cap = Impatient.Access(AwaitReturn(), new MemberAccessPath(), Task.FromResult<IDisposable>(new TestInterfaceImpl2()));
using (var proxy = new BareProxy(cap)) using (var proxy = new BareProxy(cap))
{ {
await proxy.WhenResolved; await proxy.WhenResolved;

View File

@ -608,6 +608,7 @@ namespace Capnp.Net.Runtime.Tests.GenImpls
{ {
lock (_lock) lock (_lock)
{ {
Assert.AreEqual(expected, _counter);
return Task.FromResult(_counter++); return Task.FromResult(_counter++);
} }
} }

View File

@ -132,7 +132,7 @@ namespace Capnp.Net.Runtime.Tests
result.WriteData(0, 654321); result.WriteData(0, 654321);
mock.Return.SetResult(result); mock.Return.SetResult(result);
Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
var outresult = answer.WhenReturned.Result; var outresult = answer.WhenReturned.Result;
Assert.AreEqual(ObjectKind.Struct, outresult.Kind); Assert.AreEqual(ObjectKind.Struct, outresult.Kind);
Assert.AreEqual(654321, outresult.ReadDataInt(0)); Assert.AreEqual(654321, outresult.ReadDataInt(0));
@ -170,7 +170,7 @@ namespace Capnp.Net.Runtime.Tests
mock.Return.SetCanceled(); mock.Return.SetCanceled();
Assert.IsTrue(Assert.ThrowsExceptionAsync<TaskCanceledException>(() => answer.WhenReturned).Wait(MediumNonDbgTimeout)); Assert.IsTrue(Assert.ThrowsExceptionAsync<TaskCanceledException>(async () => await answer.WhenReturned).Wait(MediumNonDbgTimeout));
} }
} }
} }
@ -266,7 +266,8 @@ namespace Capnp.Net.Runtime.Tests
// Even after the client cancelled the call, the server must still send // Even after the client cancelled the call, the server must still send
// a response. // a response.
Assert.IsTrue(answer.WhenReturned.ContinueWith(t => { }).Wait(MediumNonDbgTimeout)); async Task AwaitWhenReturned() => await answer.WhenReturned;
Assert.IsTrue(AwaitWhenReturned().ContinueWith(t => { }).Wait(MediumNonDbgTimeout));
} }
finally finally
{ {
@ -312,7 +313,7 @@ namespace Capnp.Net.Runtime.Tests
mock.Return.SetException(new MyTestException()); mock.Return.SetException(new MyTestException());
var exTask = Assert.ThrowsExceptionAsync<RpcException>(() => answer.WhenReturned); var exTask = Assert.ThrowsExceptionAsync<RpcException>(async () => await answer.WhenReturned);
Assert.IsTrue(exTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(exTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(exTask.Result.Message.Contains(new MyTestException().Message)); Assert.IsTrue(exTask.Result.Message.Contains(new MyTestException().Message));
} }
@ -367,7 +368,7 @@ namespace Capnp.Net.Runtime.Tests
mock.Return.SetResult(result); mock.Return.SetResult(result);
Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.IsFalse(ct.IsCancellationRequested); Assert.IsFalse(ct.IsCancellationRequested);
Assert.IsTrue(mock2.WhenCalled.Wait(MediumNonDbgTimeout)); Assert.IsTrue(mock2.WhenCalled.Wait(MediumNonDbgTimeout));
@ -383,7 +384,7 @@ namespace Capnp.Net.Runtime.Tests
result2.WriteData(0, 222222); result2.WriteData(0, 222222);
mock2.Return.SetResult(result2); mock2.Return.SetResult(result2);
Assert.IsTrue(answer2.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer2.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
var outresult2 = answer2.WhenReturned.Result; var outresult2 = answer2.WhenReturned.Result;
Assert.AreEqual(ObjectKind.Struct, outresult2.Kind); Assert.AreEqual(ObjectKind.Struct, outresult2.Kind);
Assert.AreEqual(222222, outresult2.ReadDataInt(0)); Assert.AreEqual(222222, outresult2.ReadDataInt(0));
@ -443,7 +444,7 @@ namespace Capnp.Net.Runtime.Tests
using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2)) using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2))
{ {
Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(mock2.WhenCalled.Wait(MediumNonDbgTimeout)); Assert.IsTrue(mock2.WhenCalled.Wait(MediumNonDbgTimeout));
(var interfaceId2, var methodId2, var inargs2, var ct2) = mock2.WhenCalled.Result; (var interfaceId2, var methodId2, var inargs2, var ct2) = mock2.WhenCalled.Result;
@ -457,7 +458,7 @@ namespace Capnp.Net.Runtime.Tests
result2.WriteData(0, 222222); result2.WriteData(0, 222222);
mock2.Return.SetResult(result2); mock2.Return.SetResult(result2);
Assert.IsTrue(answer2.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer2.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
var outresult2 = answer2.WhenReturned.Result; var outresult2 = answer2.WhenReturned.Result;
Assert.AreEqual(ObjectKind.Struct, outresult2.Kind); Assert.AreEqual(ObjectKind.Struct, outresult2.Kind);
Assert.AreEqual(222222, outresult2.ReadDataInt(0)); Assert.AreEqual(222222, outresult2.ReadDataInt(0));
@ -521,7 +522,7 @@ namespace Capnp.Net.Runtime.Tests
mock.Return.SetResult(result); mock.Return.SetResult(result);
Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.IsFalse(ct.IsCancellationRequested); Assert.IsFalse(ct.IsCancellationRequested);
var args4 = DynamicSerializerState.CreateForRpc(); var args4 = DynamicSerializerState.CreateForRpc();
@ -570,10 +571,10 @@ namespace Capnp.Net.Runtime.Tests
ret5.WriteData(0, -4); ret5.WriteData(0, -4);
call5.Result.Result.SetResult(ret5); call5.Result.Result.SetResult(ret5);
Assert.IsTrue(answer2.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer2.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(answer3.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer3.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(answer4.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer4.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.IsTrue(answer5.WhenReturned.Wait(MediumNonDbgTimeout)); Assert.IsTrue(answer5.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout));
Assert.AreEqual(-1, answer2.WhenReturned.Result.ReadDataInt(0)); Assert.AreEqual(-1, answer2.WhenReturned.Result.ReadDataInt(0));
Assert.AreEqual(-2, answer3.WhenReturned.Result.ReadDataInt(0)); Assert.AreEqual(-2, answer3.WhenReturned.Result.ReadDataInt(0));
@ -686,7 +687,7 @@ namespace Capnp.Net.Runtime.Tests
mock.Return.SetResult(result); mock.Return.SetResult(result);
Assert.IsTrue(Assert.ThrowsExceptionAsync<TaskCanceledException>( Assert.IsTrue(Assert.ThrowsExceptionAsync<TaskCanceledException>(
() => answer2.WhenReturned).Wait(MediumNonDbgTimeout)); async () => await answer2.WhenReturned).Wait(MediumNonDbgTimeout));
} }
} }
} }

View File

@ -238,7 +238,7 @@ namespace Capnp.Net.Runtime.Tests
{ {
var fooTask2 = main2.Foo(123, null); var fooTask2 = main2.Foo(123, null);
Assert.IsTrue(fooTask2.Wait(MediumNonDbgTimeout)); Assert.IsTrue(fooTask2.Wait(MediumNonDbgTimeout));
Assert.IsTrue(fooTask2.C().GetCallSequence(1).Wait(MediumNonDbgTimeout)); Assert.IsTrue(fooTask2.C().GetCallSequence(0).Wait(MediumNonDbgTimeout));
} }
} }
} }

View File

@ -423,8 +423,8 @@ namespace Capnp.Net.Runtime.Tests
Assert.AreEqual(456u, promise.Result.I); Assert.AreEqual(456u, promise.Result.I);
Assert.AreEqual("from TestTailCaller", promise.Result.T); Assert.AreEqual("from TestTailCaller", promise.Result.T);
var dependentCall1 = c.GetCallSequence(0, default); var dependentCall1 = c.GetCallSequence(1, default);
var dependentCall2 = c.GetCallSequence(0, default); var dependentCall2 = c.GetCallSequence(2, default);
AssertOutput(stdout, "foo"); AssertOutput(stdout, "foo");
Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout)); Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout));

View File

@ -276,7 +276,7 @@ namespace Capnp.Net.Runtime.Tests
public void RunTest(Action<ITestbed> action) public void RunTest(Action<ITestbed> action)
{ {
(_server, _client) = SetupClientServerPair(_options); (_server, _client) = SetupClientServerPair(_options);
Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, 2 * MediumNonDbgTimeout)); Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, LargeNonDbgTimeout));
var conn = _server.Connections[0]; var conn = _server.Connections[0];
using (_server) using (_server)

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Capnp.Rpc namespace Capnp.Rpc
@ -15,7 +16,7 @@ namespace Capnp.Rpc
/// <summary> /// <summary>
/// Task which will complete when the RPC returns, delivering its result struct. /// Task which will complete when the RPC returns, delivering its result struct.
/// </summary> /// </summary>
Task<DeserializerState> WhenReturned { get; } StrictlyOrderedAwaitTask<DeserializerState> WhenReturned { get; }
/// <summary> /// <summary>
/// Creates a low-level capability for promise pipelining. /// Creates a low-level capability for promise pipelining.

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -18,16 +19,17 @@ namespace Capnp.Rpc.Interception
readonly CancellationTokenSource _cancelFromAlice = new CancellationTokenSource(); readonly CancellationTokenSource _cancelFromAlice = new CancellationTokenSource();
public PromisedAnswer(CallContext callContext) public PromisedAnswer(CallContext callContext)
{ {
_callContext = callContext; _callContext = callContext;
WhenReturned = _futureResult.Task.EnforceAwaitOrder();
} }
public Task<DeserializerState> WhenReturned => _futureResult.Task; public StrictlyOrderedAwaitTask<DeserializerState> WhenReturned { get; }
public CancellationToken CancelFromAlice => _cancelFromAlice.Token; public CancellationToken CancelFromAlice => _cancelFromAlice.Token;
public ConsumedCapability Access(MemberAccessPath access) public ConsumedCapability Access(MemberAccessPath access)
{ {
return _callContext._censorCapability.Policy.Attach<ConsumedCapability>(new LocalAnswerCapability(_futureResult.Task, access)); return _callContext._censorCapability.Policy.Attach<ConsumedCapability>(new LocalAnswerCapability(WhenReturned, access));
} }
public ConsumedCapability Access(MemberAccessPath _, Task<IDisposable?> task) public ConsumedCapability Access(MemberAccessPath _, Task<IDisposable?> task)

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -11,7 +12,7 @@ namespace Capnp.Rpc
public LocalAnswer(CancellationTokenSource cts, Task<DeserializerState> call) public LocalAnswer(CancellationTokenSource cts, Task<DeserializerState> call)
{ {
_cts = cts ?? throw new ArgumentNullException(nameof(cts)); _cts = cts ?? throw new ArgumentNullException(nameof(cts));
WhenReturned = call ?? throw new ArgumentNullException(nameof(call)); WhenReturned = call?.EnforceAwaitOrder() ?? throw new ArgumentNullException(nameof(call));
CleanupAfterReturn(); CleanupAfterReturn();
} }
@ -23,7 +24,7 @@ namespace Capnp.Rpc
finally { _cts.Dispose(); } finally { _cts.Dispose(); }
} }
public Task<DeserializerState> WhenReturned { get; } public StrictlyOrderedAwaitTask<DeserializerState> WhenReturned { get; }
public bool IsTailCall => false; public bool IsTailCall => false;

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -7,7 +8,7 @@ namespace Capnp.Rpc
class LocalAnswerCapability : RefCountingCapability, IResolvingCapability class LocalAnswerCapability : RefCountingCapability, IResolvingCapability
{ {
static async Task<Proxy> TransferOwnershipToDummyProxy(Task<DeserializerState> answer, MemberAccessPath access) static async Task<Proxy> TransferOwnershipToDummyProxy(StrictlyOrderedAwaitTask<DeserializerState> answer, MemberAccessPath access)
{ {
var result = await answer; var result = await answer;
var cap = access.Eval(result); var cap = access.Eval(result);
@ -23,7 +24,7 @@ namespace Capnp.Rpc
_whenResolvedProxy = proxyTask; _whenResolvedProxy = proxyTask;
} }
public LocalAnswerCapability(Task<DeserializerState> answer, MemberAccessPath access): public LocalAnswerCapability(StrictlyOrderedAwaitTask<DeserializerState> answer, MemberAccessPath access):
this(TransferOwnershipToDummyProxy(answer, access)) this(TransferOwnershipToDummyProxy(answer, access))
{ {

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -53,6 +54,7 @@ namespace Capnp.Rpc
} }
readonly TaskCompletionSource<DeserializerState> _tcs = new TaskCompletionSource<DeserializerState>(); readonly TaskCompletionSource<DeserializerState> _tcs = new TaskCompletionSource<DeserializerState>();
readonly StrictlyOrderedAwaitTask<DeserializerState> _whenReturned;
readonly uint _questionId; readonly uint _questionId;
ConsumedCapability? _target; ConsumedCapability? _target;
SerializerState? _inParams; SerializerState? _inParams;
@ -64,6 +66,8 @@ namespace Capnp.Rpc
_questionId = id; _questionId = id;
_target = target; _target = target;
_inParams = inParams; _inParams = inParams;
_whenReturned = _tcs.Task.EnforceAwaitOrder();
StateFlags = inParams == null ? State.Sent : State.None; StateFlags = inParams == null ? State.Sent : State.None;
if (target != null) if (target != null)
@ -81,7 +85,7 @@ namespace Capnp.Rpc
/// <summary> /// <summary>
/// Eventually returns the server answer /// Eventually returns the server answer
/// </summary> /// </summary>
public Task<DeserializerState> WhenReturned => _tcs.Task; public StrictlyOrderedAwaitTask<DeserializerState> WhenReturned => _whenReturned;
/// <summary> /// <summary>
/// Whether this question represents a tail call /// Whether this question represents a tail call

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Diagnostics; using System.Diagnostics;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -72,7 +73,7 @@ namespace Capnp.Rpc
return null; return null;
} }
async void TrackCall(Task call) async void TrackCall(StrictlyOrderedAwaitTask call)
{ {
try try
{ {

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Capnp.Rpc namespace Capnp.Rpc
@ -38,7 +39,7 @@ namespace Capnp.Rpc
{ {
} }
async void ReAllowFinishWhenDone(Task task) async void ReAllowFinishWhenDone(StrictlyOrderedAwaitTask task)
{ {
try try
{ {

View File

@ -7,7 +7,7 @@ using System.Threading.Tasks;
namespace Capnp.Util namespace Capnp.Util
{ {
internal class StrictlyOrderedAwaitTask: INotifyCompletion public class StrictlyOrderedAwaitTask: INotifyCompletion
{ {
class Cover { } class Cover { }
class Seal { } class Seal { }
@ -99,7 +99,7 @@ namespace Capnp.Util
public Task WrappedTask => _awaitedTask; public Task WrappedTask => _awaitedTask;
} }
internal class StrictlyOrderedAwaitTask<T> : StrictlyOrderedAwaitTask public class StrictlyOrderedAwaitTask<T> : StrictlyOrderedAwaitTask
{ {
public StrictlyOrderedAwaitTask(Task<T> awaitedTask): base(awaitedTask) public StrictlyOrderedAwaitTask(Task<T> awaitedTask): base(awaitedTask)
{ {
@ -115,7 +115,7 @@ namespace Capnp.Util
} }
internal static class StrictlyOrderedTaskExtensions public static class StrictlyOrderedTaskExtensions
{ {
public static StrictlyOrderedAwaitTask<T> EnforceAwaitOrder<T>(this Task<T> task) public static StrictlyOrderedAwaitTask<T> EnforceAwaitOrder<T>(this Task<T> task)
{ {