From 16c9f8871a8028d78e04b797c119b9092a0fec46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20K=C3=B6llner?= Date: Mon, 20 Apr 2020 08:14:02 +0200 Subject: [PATCH] attempting to fix race cond. with StrictlyOrderedAwaitTask --- Capnp.Net.Runtime.Tests/General.cs | 7 ++++- Capnp.Net.Runtime.Tests/ImpatientTests.cs | 13 +++++++-- .../Mock/TestCapImplementations.cs | 1 + Capnp.Net.Runtime.Tests/TcpRpc.cs | 29 ++++++++++--------- .../TcpRpcAdvancedStuff.cs | 2 +- Capnp.Net.Runtime.Tests/TcpRpcInterop.cs | 4 +-- Capnp.Net.Runtime.Tests/Util/TestBase.cs | 2 +- Capnp.Net.Runtime/Rpc/IPromisedAnswer.cs | 5 ++-- .../Rpc/Interception/CallContext.cs | 10 ++++--- Capnp.Net.Runtime/Rpc/LocalAnswer.cs | 7 +++-- .../Rpc/LocalAnswerCapability.cs | 7 +++-- Capnp.Net.Runtime/Rpc/PendingQuestion.cs | 8 +++-- Capnp.Net.Runtime/Rpc/PromisedCapability.cs | 5 ++-- .../Rpc/RemoteAnswerCapability.cs | 5 ++-- .../Util/StrictlyOrderedAwaitTask.cs | 6 ++-- 15 files changed, 68 insertions(+), 43 deletions(-) diff --git a/Capnp.Net.Runtime.Tests/General.cs b/Capnp.Net.Runtime.Tests/General.cs index 3ca548b..856771b 100644 --- a/Capnp.Net.Runtime.Tests/General.cs +++ b/Capnp.Net.Runtime.Tests/General.cs @@ -48,7 +48,12 @@ namespace Capnp.Net.Runtime.Tests { readonly TaskCompletionSource _tcs = new TaskCompletionSource(); - public Task WhenReturned => _tcs.Task; + public PromisedAnswerMock() + { + WhenReturned = _tcs.Task.EnforceAwaitOrder(); + } + + public StrictlyOrderedAwaitTask WhenReturned { get; } public void Return() { diff --git a/Capnp.Net.Runtime.Tests/ImpatientTests.cs b/Capnp.Net.Runtime.Tests/ImpatientTests.cs index a009a43..f57dad7 100644 --- a/Capnp.Net.Runtime.Tests/ImpatientTests.cs +++ b/Capnp.Net.Runtime.Tests/ImpatientTests.cs @@ -1,5 +1,6 @@ using Capnp.Net.Runtime.Tests.GenImpls; using Capnp.Rpc; +using Capnp.Util; using Capnproto_test.Capnp.Test; using Microsoft.VisualStudio.TestTools.UnitTesting; using System; @@ -140,7 +141,12 @@ namespace Capnp.Net.Runtime.Tests class PromisedAnswerMock : IPromisedAnswer { readonly TaskCompletionSource _tcs = new TaskCompletionSource(); - public Task WhenReturned => _tcs.Task; + public StrictlyOrderedAwaitTask WhenReturned { get; } + + public PromisedAnswerMock() + { + WhenReturned = _tcs.Task.EnforceAwaitOrder(); + } public bool IsTailCall => false; @@ -164,7 +170,7 @@ namespace Capnp.Net.Runtime.Tests { #pragma warning disable CS0618 var answer = new PromisedAnswerMock(); - Assert.ThrowsException(() => Impatient.GetAnswer(answer.WhenReturned)); + Assert.ThrowsException(() => Impatient.GetAnswer(Task.FromResult(new object()))); var t = Impatient.MakePipelineAware(answer, _ => _); Assert.AreEqual(answer, Impatient.GetAnswer(t)); #pragma warning restore CS0618 @@ -174,7 +180,8 @@ namespace Capnp.Net.Runtime.Tests public async Task Access() { var answer = new PromisedAnswerMock(); - var cap = Impatient.Access(answer.WhenReturned, new MemberAccessPath(), Task.FromResult(new TestInterfaceImpl2())); + async Task AwaitReturn() => await answer.WhenReturned; + var cap = Impatient.Access(AwaitReturn(), new MemberAccessPath(), Task.FromResult(new TestInterfaceImpl2())); using (var proxy = new BareProxy(cap)) { await proxy.WhenResolved; diff --git a/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs b/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs index 7f1cac5..5d19630 100644 --- a/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs +++ b/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs @@ -608,6 +608,7 @@ namespace Capnp.Net.Runtime.Tests.GenImpls { lock (_lock) { + Assert.AreEqual(expected, _counter); return Task.FromResult(_counter++); } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpc.cs b/Capnp.Net.Runtime.Tests/TcpRpc.cs index 1285b09..2abdc4e 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpc.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpc.cs @@ -132,7 +132,7 @@ namespace Capnp.Net.Runtime.Tests result.WriteData(0, 654321); mock.Return.SetResult(result); - Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); var outresult = answer.WhenReturned.Result; Assert.AreEqual(ObjectKind.Struct, outresult.Kind); Assert.AreEqual(654321, outresult.ReadDataInt(0)); @@ -170,7 +170,7 @@ namespace Capnp.Net.Runtime.Tests mock.Return.SetCanceled(); - Assert.IsTrue(Assert.ThrowsExceptionAsync(() => answer.WhenReturned).Wait(MediumNonDbgTimeout)); + Assert.IsTrue(Assert.ThrowsExceptionAsync(async () => await answer.WhenReturned).Wait(MediumNonDbgTimeout)); } } } @@ -266,7 +266,8 @@ namespace Capnp.Net.Runtime.Tests // Even after the client cancelled the call, the server must still send // a response. - Assert.IsTrue(answer.WhenReturned.ContinueWith(t => { }).Wait(MediumNonDbgTimeout)); + async Task AwaitWhenReturned() => await answer.WhenReturned; + Assert.IsTrue(AwaitWhenReturned().ContinueWith(t => { }).Wait(MediumNonDbgTimeout)); } finally { @@ -312,7 +313,7 @@ namespace Capnp.Net.Runtime.Tests mock.Return.SetException(new MyTestException()); - var exTask = Assert.ThrowsExceptionAsync(() => answer.WhenReturned); + var exTask = Assert.ThrowsExceptionAsync(async () => await answer.WhenReturned); Assert.IsTrue(exTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(exTask.Result.Message.Contains(new MyTestException().Message)); } @@ -367,7 +368,7 @@ namespace Capnp.Net.Runtime.Tests mock.Return.SetResult(result); - Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); Assert.IsFalse(ct.IsCancellationRequested); Assert.IsTrue(mock2.WhenCalled.Wait(MediumNonDbgTimeout)); @@ -383,7 +384,7 @@ namespace Capnp.Net.Runtime.Tests result2.WriteData(0, 222222); mock2.Return.SetResult(result2); - Assert.IsTrue(answer2.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer2.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); var outresult2 = answer2.WhenReturned.Result; Assert.AreEqual(ObjectKind.Struct, outresult2.Kind); Assert.AreEqual(222222, outresult2.ReadDataInt(0)); @@ -443,7 +444,7 @@ namespace Capnp.Net.Runtime.Tests using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2)) { - Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(mock2.WhenCalled.Wait(MediumNonDbgTimeout)); (var interfaceId2, var methodId2, var inargs2, var ct2) = mock2.WhenCalled.Result; @@ -457,7 +458,7 @@ namespace Capnp.Net.Runtime.Tests result2.WriteData(0, 222222); mock2.Return.SetResult(result2); - Assert.IsTrue(answer2.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer2.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); var outresult2 = answer2.WhenReturned.Result; Assert.AreEqual(ObjectKind.Struct, outresult2.Kind); Assert.AreEqual(222222, outresult2.ReadDataInt(0)); @@ -521,7 +522,7 @@ namespace Capnp.Net.Runtime.Tests mock.Return.SetResult(result); - Assert.IsTrue(answer.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); Assert.IsFalse(ct.IsCancellationRequested); var args4 = DynamicSerializerState.CreateForRpc(); @@ -570,10 +571,10 @@ namespace Capnp.Net.Runtime.Tests ret5.WriteData(0, -4); call5.Result.Result.SetResult(ret5); - Assert.IsTrue(answer2.WhenReturned.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(answer3.WhenReturned.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(answer4.WhenReturned.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(answer5.WhenReturned.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer2.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer3.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer4.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(answer5.WhenReturned.WrappedTask.Wait(MediumNonDbgTimeout)); Assert.AreEqual(-1, answer2.WhenReturned.Result.ReadDataInt(0)); Assert.AreEqual(-2, answer3.WhenReturned.Result.ReadDataInt(0)); @@ -686,7 +687,7 @@ namespace Capnp.Net.Runtime.Tests mock.Return.SetResult(result); Assert.IsTrue(Assert.ThrowsExceptionAsync( - () => answer2.WhenReturned).Wait(MediumNonDbgTimeout)); + async () => await answer2.WhenReturned).Wait(MediumNonDbgTimeout)); } } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs index e016ad5..0014891 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs @@ -238,7 +238,7 @@ namespace Capnp.Net.Runtime.Tests { var fooTask2 = main2.Foo(123, null); Assert.IsTrue(fooTask2.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(fooTask2.C().GetCallSequence(1).Wait(MediumNonDbgTimeout)); + Assert.IsTrue(fooTask2.C().GetCallSequence(0).Wait(MediumNonDbgTimeout)); } } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs index d274d67..6ec007a 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs @@ -423,8 +423,8 @@ namespace Capnp.Net.Runtime.Tests Assert.AreEqual(456u, promise.Result.I); Assert.AreEqual("from TestTailCaller", promise.Result.T); - var dependentCall1 = c.GetCallSequence(0, default); - var dependentCall2 = c.GetCallSequence(0, default); + var dependentCall1 = c.GetCallSequence(1, default); + var dependentCall2 = c.GetCallSequence(2, default); AssertOutput(stdout, "foo"); Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout)); diff --git a/Capnp.Net.Runtime.Tests/Util/TestBase.cs b/Capnp.Net.Runtime.Tests/Util/TestBase.cs index 13a93bb..dc94dd3 100644 --- a/Capnp.Net.Runtime.Tests/Util/TestBase.cs +++ b/Capnp.Net.Runtime.Tests/Util/TestBase.cs @@ -276,7 +276,7 @@ namespace Capnp.Net.Runtime.Tests public void RunTest(Action action) { (_server, _client) = SetupClientServerPair(_options); - Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, 2 * MediumNonDbgTimeout)); + Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, LargeNonDbgTimeout)); var conn = _server.Connections[0]; using (_server) diff --git a/Capnp.Net.Runtime/Rpc/IPromisedAnswer.cs b/Capnp.Net.Runtime/Rpc/IPromisedAnswer.cs index 6c0de2f..5bb9d9c 100644 --- a/Capnp.Net.Runtime/Rpc/IPromisedAnswer.cs +++ b/Capnp.Net.Runtime/Rpc/IPromisedAnswer.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Threading.Tasks; namespace Capnp.Rpc @@ -15,7 +16,7 @@ namespace Capnp.Rpc /// /// Task which will complete when the RPC returns, delivering its result struct. /// - Task WhenReturned { get; } + StrictlyOrderedAwaitTask WhenReturned { get; } /// /// Creates a low-level capability for promise pipelining. diff --git a/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs b/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs index 1caf2b6..b8b3a52 100644 --- a/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs +++ b/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Threading; using System.Threading.Tasks; @@ -18,16 +19,17 @@ namespace Capnp.Rpc.Interception readonly CancellationTokenSource _cancelFromAlice = new CancellationTokenSource(); public PromisedAnswer(CallContext callContext) - { + { _callContext = callContext; + WhenReturned = _futureResult.Task.EnforceAwaitOrder(); } - public Task WhenReturned => _futureResult.Task; + public StrictlyOrderedAwaitTask WhenReturned { get; } public CancellationToken CancelFromAlice => _cancelFromAlice.Token; public ConsumedCapability Access(MemberAccessPath access) { - return _callContext._censorCapability.Policy.Attach(new LocalAnswerCapability(_futureResult.Task, access)); + return _callContext._censorCapability.Policy.Attach(new LocalAnswerCapability(WhenReturned, access)); } public ConsumedCapability Access(MemberAccessPath _, Task task) diff --git a/Capnp.Net.Runtime/Rpc/LocalAnswer.cs b/Capnp.Net.Runtime/Rpc/LocalAnswer.cs index 516b6a7..294412e 100644 --- a/Capnp.Net.Runtime/Rpc/LocalAnswer.cs +++ b/Capnp.Net.Runtime/Rpc/LocalAnswer.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Threading; using System.Threading.Tasks; @@ -11,7 +12,7 @@ namespace Capnp.Rpc public LocalAnswer(CancellationTokenSource cts, Task call) { _cts = cts ?? throw new ArgumentNullException(nameof(cts)); - WhenReturned = call ?? throw new ArgumentNullException(nameof(call)); + WhenReturned = call?.EnforceAwaitOrder() ?? throw new ArgumentNullException(nameof(call)); CleanupAfterReturn(); } @@ -23,7 +24,7 @@ namespace Capnp.Rpc finally { _cts.Dispose(); } } - public Task WhenReturned { get; } + public StrictlyOrderedAwaitTask WhenReturned { get; } public bool IsTailCall => false; diff --git a/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs b/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs index be55b2c..17d9c0c 100644 --- a/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs +++ b/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Threading; using System.Threading.Tasks; @@ -7,7 +8,7 @@ namespace Capnp.Rpc class LocalAnswerCapability : RefCountingCapability, IResolvingCapability { - static async Task TransferOwnershipToDummyProxy(Task answer, MemberAccessPath access) + static async Task TransferOwnershipToDummyProxy(StrictlyOrderedAwaitTask answer, MemberAccessPath access) { var result = await answer; var cap = access.Eval(result); @@ -23,7 +24,7 @@ namespace Capnp.Rpc _whenResolvedProxy = proxyTask; } - public LocalAnswerCapability(Task answer, MemberAccessPath access): + public LocalAnswerCapability(StrictlyOrderedAwaitTask answer, MemberAccessPath access): this(TransferOwnershipToDummyProxy(answer, access)) { diff --git a/Capnp.Net.Runtime/Rpc/PendingQuestion.cs b/Capnp.Net.Runtime/Rpc/PendingQuestion.cs index 97be61a..73e074c 100644 --- a/Capnp.Net.Runtime/Rpc/PendingQuestion.cs +++ b/Capnp.Net.Runtime/Rpc/PendingQuestion.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Collections.Generic; using System.Diagnostics; using System.Threading.Tasks; @@ -53,6 +54,7 @@ namespace Capnp.Rpc } readonly TaskCompletionSource _tcs = new TaskCompletionSource(); + readonly StrictlyOrderedAwaitTask _whenReturned; readonly uint _questionId; ConsumedCapability? _target; SerializerState? _inParams; @@ -64,6 +66,8 @@ namespace Capnp.Rpc _questionId = id; _target = target; _inParams = inParams; + _whenReturned = _tcs.Task.EnforceAwaitOrder(); + StateFlags = inParams == null ? State.Sent : State.None; if (target != null) @@ -81,7 +85,7 @@ namespace Capnp.Rpc /// /// Eventually returns the server answer /// - public Task WhenReturned => _tcs.Task; + public StrictlyOrderedAwaitTask WhenReturned => _whenReturned; /// /// Whether this question represents a tail call diff --git a/Capnp.Net.Runtime/Rpc/PromisedCapability.cs b/Capnp.Net.Runtime/Rpc/PromisedCapability.cs index ad8db50..74b5849 100644 --- a/Capnp.Net.Runtime/Rpc/PromisedCapability.cs +++ b/Capnp.Net.Runtime/Rpc/PromisedCapability.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Diagnostics; using System.Threading.Tasks; @@ -72,7 +73,7 @@ namespace Capnp.Rpc return null; } - async void TrackCall(Task call) + async void TrackCall(StrictlyOrderedAwaitTask call) { try { diff --git a/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs b/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs index 53a759d..9b0b248 100644 --- a/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs +++ b/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Threading.Tasks; namespace Capnp.Rpc @@ -38,7 +39,7 @@ namespace Capnp.Rpc { } - async void ReAllowFinishWhenDone(Task task) + async void ReAllowFinishWhenDone(StrictlyOrderedAwaitTask task) { try { diff --git a/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs b/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs index a7c6602..fa0d36c 100644 --- a/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs +++ b/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs @@ -7,7 +7,7 @@ using System.Threading.Tasks; namespace Capnp.Util { - internal class StrictlyOrderedAwaitTask: INotifyCompletion + public class StrictlyOrderedAwaitTask: INotifyCompletion { class Cover { } class Seal { } @@ -99,7 +99,7 @@ namespace Capnp.Util public Task WrappedTask => _awaitedTask; } - internal class StrictlyOrderedAwaitTask : StrictlyOrderedAwaitTask + public class StrictlyOrderedAwaitTask : StrictlyOrderedAwaitTask { public StrictlyOrderedAwaitTask(Task awaitedTask): base(awaitedTask) { @@ -115,7 +115,7 @@ namespace Capnp.Util } - internal static class StrictlyOrderedTaskExtensions + public static class StrictlyOrderedTaskExtensions { public static StrictlyOrderedAwaitTask EnforceAwaitOrder(this Task task) {