diff --git a/Capnp.Net.Runtime.Tests/DeserializationTests.cs b/Capnp.Net.Runtime.Tests/DeserializationTests.cs index 1f286b7..e90269d 100644 --- a/Capnp.Net.Runtime.Tests/DeserializationTests.cs +++ b/Capnp.Net.Runtime.Tests/DeserializationTests.cs @@ -151,6 +151,9 @@ namespace Capnp.Net.Runtime.Tests Assert.AreEqual(2, asListOfStructs.Count); Assert.AreEqual(0ul, asListOfStructs[0].ReadDataULong(0)); Assert.AreEqual(ulong.MaxValue, asListOfStructs[1].ReadDataULong(0)); + Assert.ThrowsException(() => asListOfStructs[-1].ReadDataUShort(0)); + Assert.ThrowsException(() => asListOfStructs[3].ReadDataUShort(0)); + CollectionAssert.AreEqual(new ulong[] { 0, ulong.MaxValue }, asListOfStructs.Select(_ => _.ReadDataULong(0)).ToArray()); } [TestMethod] diff --git a/Capnp.Net.Runtime.Tests/General.cs b/Capnp.Net.Runtime.Tests/General.cs index f5cb811..0b564f6 100644 --- a/Capnp.Net.Runtime.Tests/General.cs +++ b/Capnp.Net.Runtime.Tests/General.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Dataflow; namespace Capnp.Net.Runtime.Tests { @@ -42,33 +43,6 @@ namespace Capnp.Net.Runtime.Tests Task.WhenAll(tasks).Wait(); } - [TestMethod] - public void AwaitOrderTest2() - { - int returnCounter = 0; - - async Task ExpectCount(Task task, int count) - { - await task; - Assert.AreEqual(count, returnCounter++); - } - - var tcs = new TaskCompletionSource(); - var cts = new CancellationTokenSource(); - - var tasks = - from i in Enumerable.Range(0, 100) - select ExpectCount(tcs.Task.ContinueWith( - t => t, - cts.Token, - TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Current), i); - - tcs.SetResult(0); - - Task.WhenAll(tasks).Wait(); - } - class PromisedAnswerMock : IPromisedAnswer { readonly TaskCompletionSource _tcs = new TaskCompletionSource(); diff --git a/Capnp.Net.Runtime.Tests/Interception.cs b/Capnp.Net.Runtime.Tests/Interception.cs index f2f7b71..0e0e010 100644 --- a/Capnp.Net.Runtime.Tests/Interception.cs +++ b/Capnp.Net.Runtime.Tests/Interception.cs @@ -67,7 +67,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = policy.Attach(new TestInterfaceImpl(counters)); @@ -106,7 +106,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = policy.Attach(new TestInterfaceImpl(counters)); @@ -146,7 +146,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -204,7 +204,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -237,7 +237,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -268,7 +268,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -334,7 +334,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -369,7 +369,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -406,7 +406,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); @@ -431,7 +431,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestTailCallerImpl(counters); @@ -463,7 +463,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestMoreStuffImpl(counters); @@ -503,7 +503,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = policy.Attach(new TestMoreStuffImpl(counters)); @@ -560,7 +560,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestMoreStuffImpl(counters); @@ -628,7 +628,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestMoreStuffImpl(counters); @@ -668,7 +668,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); server.Main = implAc; using (var main = client.GetMain()) diff --git a/Capnp.Net.Runtime.Tests/LocalRpc.cs b/Capnp.Net.Runtime.Tests/LocalRpc.cs index 7adc180..55e693a 100644 --- a/Capnp.Net.Runtime.Tests/LocalRpc.cs +++ b/Capnp.Net.Runtime.Tests/LocalRpc.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Dataflow; namespace Capnp.Net.Runtime.Tests { @@ -138,5 +139,59 @@ namespace Capnp.Net.Runtime.Tests { NewLocalTestbed().RunTest(Testsuite.Ownership3); } + + [TestMethod] + public void EagerRace() + { + var impl = new TestMoreStuffImpl(new Counters()); + var tcs = new TaskCompletionSource(); + using (var promise = tcs.Task.Eager(true)) + using (var cts = new CancellationTokenSource()) + { + var bb = new BufferBlock>(); + int counter = 0; + + void Generator() + { + while (!cts.IsCancellationRequested) + { + bb.Post(promise.GetCallSequence((uint)Volatile.Read(ref counter))); + Interlocked.Increment(ref counter); + } + + bb.Complete(); + } + + async Task Verifier() + { + uint i = 0; + while (true) + { + Task t; + + try + { + t = await bb.ReceiveAsync(); + } + catch (InvalidOperationException) + { + break; + } + + uint j = await t; + Assert.AreEqual(i, j); + i++; + } + } + + var genTask = Task.Run(() => Generator()); + var verTask = Verifier(); + SpinWait.SpinUntil(() => Volatile.Read(ref counter) >= 100); + tcs.SetResult(impl); + cts.Cancel(); + Assert.IsTrue(genTask.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(verTask.Wait(MediumNonDbgTimeout)); + } + } } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpc.cs b/Capnp.Net.Runtime.Tests/TcpRpc.cs index 009c94d..af749ba 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpc.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpc.cs @@ -68,7 +68,7 @@ namespace Capnp.Net.Runtime.Tests { try { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); } @@ -97,7 +97,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -116,7 +116,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -134,7 +134,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -175,7 +175,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -214,7 +214,7 @@ namespace Capnp.Net.Runtime.Tests { try { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -258,7 +258,7 @@ namespace Capnp.Net.Runtime.Tests { try { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -317,7 +317,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -354,7 +354,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -427,7 +427,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -503,7 +503,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -620,7 +620,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); @@ -667,7 +667,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); SpinWait.SpinUntil(() => server.ConnectionCount > 0, MediumNonDbgTimeout); Assert.AreEqual(1, server.ConnectionCount); diff --git a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs index 6d41fc9..1349fd5 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs @@ -26,7 +26,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = SetupClient()) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -62,8 +62,8 @@ namespace Capnp.Net.Runtime.Tests using (var client1 = SetupClient()) using (var client2 = SetupClient()) { - Assert.IsTrue(client1.WhenConnected.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(client2.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client1.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client2.WhenConnected.Wait(MediumNonDbgTimeout)); using (var main = client1.GetMain()) { @@ -132,7 +132,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = SetupClient()) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -154,7 +154,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = SetupClient()) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -184,7 +184,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = SetupClient()) { - Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); using (var main = client.GetMain()) { @@ -196,7 +196,7 @@ namespace Capnp.Net.Runtime.Tests using (var client2 = SetupClient()) { - Assert.IsTrue(client2.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client2.WhenConnected.Wait(MediumNonDbgTimeout)); using (var main2 = client2.GetMain()) { @@ -221,7 +221,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = SetupClient()) { - Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); using (var main = client.GetMain()) { @@ -232,7 +232,7 @@ namespace Capnp.Net.Runtime.Tests using (var c = fooTask.Result.C) using (var client2 = SetupClient()) { - Assert.IsTrue(client2.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client2.WhenConnected.Wait(MediumNonDbgTimeout)); using (var main2 = client2.GetMain()) { @@ -255,7 +255,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = SetupClient()) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { diff --git a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs index 7ad9852..d274d67 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs @@ -135,7 +135,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -185,7 +185,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -239,7 +239,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -305,7 +305,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -407,7 +407,7 @@ namespace Capnp.Net.Runtime.Tests client.AttachTracer(tracer); client.Connect("localhost", TcpPort); - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -475,7 +475,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -520,7 +520,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -589,7 +589,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -682,7 +682,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var destructionPromise = new TaskCompletionSource(); var destructionTask = destructionPromise.Task; @@ -740,7 +740,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -802,36 +802,15 @@ namespace Capnp.Net.Runtime.Tests { LaunchCompatTestProcess("server:MoreStuff", stdout => { - int retry = 0; - - label: using (var client = new TcpRpcClient("localhost", TcpPort)) { - Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout), "client connect"); + //Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout), "client connect"); - using (var main = client.GetMain()) + using (var wrapped = client.GetMain()) { - var resolving = main as IResolvingCapability; - - bool success; - - try - { - success = resolving.WhenResolved.Wait(MediumNonDbgTimeout); - } - catch - { - success = false; - } - - if (!success) - { - if (++retry == 5) - { - Assert.Fail("Attempting to obtain bootstrap interface failed. Bailing out."); - } - goto label; - } + var unwrap = wrapped.Unwrap(); + Assert.IsTrue(unwrap.Wait(MediumNonDbgTimeout)); + var main = unwrap.Result; var cap = new TestCallOrderImpl(); cap.CountToDispose = 6; @@ -892,7 +871,7 @@ namespace Capnp.Net.Runtime.Tests label: using (var client = new TcpRpcClient("localhost", TcpPort)) { - Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout), "client connect"); + //Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout), "client connect"); using (var main = client.GetMain()) { @@ -986,7 +965,7 @@ namespace Capnp.Net.Runtime.Tests using (var client = new TcpRpcClient("localhost", TcpPort)) { - Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); + //Assert.IsTrue(client.WhenConnected.Wait(MediumNonDbgTimeout)); using (var main = client.GetMain()) { @@ -1079,7 +1058,7 @@ namespace Capnp.Net.Runtime.Tests label: using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { @@ -1152,7 +1131,7 @@ namespace Capnp.Net.Runtime.Tests { using (var client = new TcpRpcClient("localhost", TcpPort)) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); using (var main = client.GetMain()) { diff --git a/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs b/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs index 9e7cd2a..20fc392 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs @@ -43,7 +43,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestMoreStuffImpl(counters); @@ -151,7 +151,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); server.Main = impl; for (int i = 0; i < 10; i++) diff --git a/Capnp.Net.Runtime.Tests/TcpRpcStress.cs b/Capnp.Net.Runtime.Tests/TcpRpcStress.cs index 7fb4590..43a3cbd 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcStress.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcStress.cs @@ -35,7 +35,7 @@ namespace Capnp.Net.Runtime.Tests using (server) using (client) { - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); var impl = new TestMoreStuffImpl(counters); @@ -109,7 +109,7 @@ namespace Capnp.Net.Runtime.Tests server.InjectMidlayer(s => new ScatteringStream(s, 7)); client.InjectMidlayer(s => new ScatteringStream(s, 10)); client.Connect("localhost", TcpPort); - client.WhenConnected.Wait(); + //client.WhenConnected.Wait(); var counters = new Counters(); server.Main = new TestInterfaceImpl(counters); diff --git a/Capnp.Net.Runtime.Tests/Util/TestBase.cs b/Capnp.Net.Runtime.Tests/Util/TestBase.cs index 5f13b70..5770c2c 100644 --- a/Capnp.Net.Runtime.Tests/Util/TestBase.cs +++ b/Capnp.Net.Runtime.Tests/Util/TestBase.cs @@ -15,7 +15,7 @@ namespace Capnp.Net.Runtime.Tests { public interface ITestbed { - T ConnectMain(object main) where T : class; + T ConnectMain(object main) where T : class, IDisposable; void MustComplete(params Task[] tasks); void MustNotComplete(params Task[] tasks); void FlushCommunication(); @@ -241,7 +241,7 @@ namespace Capnp.Net.Runtime.Tests public void RunTest(Action action) { (_server, _client) = SetupClientServerPair(); - _client.WhenConnected.Wait(MediumNonDbgTimeout); + //_client.WhenConnected.Wait(MediumNonDbgTimeout); Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, MediumNonDbgTimeout)); var conn = _server.Connections[0]; diff --git a/Capnp.Net.Runtime/ListOfPrimitivesDeserializer.cs b/Capnp.Net.Runtime/ListOfPrimitivesDeserializer.cs index 072d371..a2f4616 100644 --- a/Capnp.Net.Runtime/ListOfPrimitivesDeserializer.cs +++ b/Capnp.Net.Runtime/ListOfPrimitivesDeserializer.cs @@ -31,7 +31,7 @@ namespace Capnp var state = _lpd.State; if (index < 0 || index >= _lpd.Count) - throw new ArgumentOutOfRangeException(nameof(index)); + throw new IndexOutOfRangeException(); state.Offset += index; state.Kind = ObjectKind.Struct; diff --git a/Capnp.Net.Runtime/Rpc/LazyCapability.cs b/Capnp.Net.Runtime/Rpc/LazyCapability.cs index 9794e3f..bd3f037 100644 --- a/Capnp.Net.Runtime/Rpc/LazyCapability.cs +++ b/Capnp.Net.Runtime/Rpc/LazyCapability.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Threading; using System.Threading.Tasks; @@ -18,11 +19,11 @@ namespace Capnp.Rpc } readonly Task? _proxyTask; - readonly Task _capTask; + readonly StrictlyOrderedAwaitTask _capTask; public LazyCapability(Task capabilityTask) { - _capTask = capabilityTask; + _capTask = capabilityTask.EnforceAwaitOrder(); } public LazyCapability(Task proxyTask) @@ -31,7 +32,7 @@ namespace Capnp.Rpc async Task AwaitCap() => (await _proxyTask!).ConsumedCap; - _capTask = AwaitCap(); + _capTask = AwaitCap().EnforceAwaitOrder(); } internal override Action? Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer) @@ -61,11 +62,13 @@ namespace Capnp.Rpc } } - public Task WhenResolved => _capTask; + async Task AwaitWhenResolved() => await _capTask; + + public Task WhenResolved => AwaitWhenResolved(); public T? GetResolvedCapability() where T: class { - if (_capTask.IsCompleted) + if (_capTask.WrappedTask.IsCompleted) { try { diff --git a/Capnp.Net.Runtime/Rpc/PolySkeleton.cs b/Capnp.Net.Runtime/Rpc/PolySkeleton.cs index c7762dd..cf08a03 100644 --- a/Capnp.Net.Runtime/Rpc/PolySkeleton.cs +++ b/Capnp.Net.Runtime/Rpc/PolySkeleton.cs @@ -25,8 +25,9 @@ namespace Capnp.Rpc if (skeleton == null) throw new ArgumentNullException(nameof(skeleton)); - skeleton.Claim(); _ifmap.Add(interfaceId, skeleton); + if (_ifmap.Count == 1) // Claiming only the first one is sufficient + skeleton.Claim(); } internal void AddInterface(Skeleton skeleton) @@ -60,8 +61,6 @@ namespace Capnp.Rpc { cap.Relinquish(); } - - base.Dispose(disposing); } internal override void Bind(object impl) diff --git a/Capnp.Net.Runtime/Rpc/TcpRpcClient.cs b/Capnp.Net.Runtime/Rpc/TcpRpcClient.cs index 6b872fd..55771ad 100644 --- a/Capnp.Net.Runtime/Rpc/TcpRpcClient.cs +++ b/Capnp.Net.Runtime/Rpc/TcpRpcClient.cs @@ -151,24 +151,20 @@ namespace Capnp.Rpc /// Bootstrap capability interface /// A proxy for the bootstrap capability /// Not connected - public TProxy GetMain() where TProxy: class + public TProxy GetMain() where TProxy: class, IDisposable { if (WhenConnected == null) { throw new InvalidOperationException("Not connecting"); } - if (!WhenConnected.IsCompleted) + async Task GetMainAsync() { - throw new InvalidOperationException("Connection not yet established"); + await WhenConnected!; + return (CapabilityReflection.CreateProxy(_inboundEndpoint!.QueryMain()) as TProxy)!; } - if (!WhenConnected.ReplacementTaskIsCompletedSuccessfully()) - { - throw new InvalidOperationException("Connection not successfully established"); - } - - return (CapabilityReflection.CreateProxy(_inboundEndpoint!.QueryMain()) as TProxy)!; + return GetMainAsync().Eager(true); } /// diff --git a/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs b/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs new file mode 100644 index 0000000..714ee2f --- /dev/null +++ b/Capnp.Net.Runtime/Util/StrictlyOrderedAwaitTask.cs @@ -0,0 +1,88 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Capnp.Util +{ + internal class StrictlyOrderedAwaitTask: INotifyCompletion + { + readonly Task _awaitedTask; + object _lock; + long _inOrder, _outOrder; + + public StrictlyOrderedAwaitTask(Task awaitedTask) + { + _awaitedTask = awaitedTask; + _lock = new object(); + } + + public StrictlyOrderedAwaitTask GetAwaiter() + { + return this; + } + + public async void OnCompleted(Action continuation) + { + object safeLock = Volatile.Read(ref _lock); + + if (safeLock == null) + { + continuation(); + return; + } + + long sequence = Interlocked.Increment(ref _inOrder) - 1; + + try + { + if (_awaitedTask.IsCompleted) + { + Interlocked.Exchange(ref _lock, null); + } + + await _awaitedTask; + } + catch + { + } + finally + { + SpinWait.SpinUntil(() => + { + lock (safeLock) + { + if (Volatile.Read(ref _outOrder) != sequence) + { + return false; + } + + Interlocked.Increment(ref _outOrder); + + continuation(); + + return true; + } + }); + } + } + + public bool IsCompleted => Volatile.Read(ref _lock) == null; + + public T GetResult() => _awaitedTask.GetAwaiter().GetResult(); + + public T Result => _awaitedTask.Result; + + public Task WrappedTask => _awaitedTask; + } + + internal static class StrictlyOrderedTaskExtensions + { + public static StrictlyOrderedAwaitTask EnforceAwaitOrder(this Task task) + { + return new StrictlyOrderedAwaitTask(task); + } + } +}