From 2369b4788a60ed64a6b029d1b07858a46f08a2d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20K=C3=B6llner?= Date: Sun, 29 Mar 2020 00:07:16 +0100 Subject: [PATCH] fixes + new tests --- Capnp.Net.Runtime.Tests/Dtbdct.cs | 46 +- ...pcErrorHandling.cs => EdgeCaseHandling.cs} | 580 +++++++++++++++++- Capnp.Net.Runtime.Tests/LocalRpc.cs | 8 +- .../Mock/TestCapImplementations.cs | 109 ++++ Capnp.Net.Runtime.Tests/TcpRpc.cs | 30 +- .../TcpRpcAdvancedStuff.cs | 6 +- Capnp.Net.Runtime.Tests/TcpRpcInterop.cs | 50 +- Capnp.Net.Runtime.Tests/TcpRpcPorted.cs | 2 +- Capnp.Net.Runtime.Tests/Testsuite.cs | 183 +++++- Capnp.Net.Runtime.Tests/Util/TestBase.cs | 19 +- Capnp.Net.Runtime/Rpc/Impatient.cs | 23 +- Capnp.Net.Runtime/Rpc/LazyCapability.cs | 3 +- Capnp.Net.Runtime/Rpc/PendingQuestion.cs | 3 + Capnp.Net.Runtime/Rpc/Proxy.cs | 20 +- .../Rpc/ResolvingCapabilityExtensions.cs | 14 +- Capnp.Net.Runtime/Rpc/RpcEngine.cs | 322 +++++----- Capnp.Net.Runtime/SerializerState.cs | 8 +- 17 files changed, 1170 insertions(+), 256 deletions(-) rename Capnp.Net.Runtime.Tests/{TcpRpcErrorHandling.cs => EdgeCaseHandling.cs} (50%) diff --git a/Capnp.Net.Runtime.Tests/Dtbdct.cs b/Capnp.Net.Runtime.Tests/Dtbdct.cs index eedecb3..fa48434 100644 --- a/Capnp.Net.Runtime.Tests/Dtbdct.cs +++ b/Capnp.Net.Runtime.Tests/Dtbdct.cs @@ -10,9 +10,15 @@ namespace Capnp.Net.Runtime.Tests public class Dtbdct: TestBase { [TestMethod] - public void Embargo() + public void EmbargoOnPromisedAnswer() { - NewDtbdctTestbed().RunTest(Testsuite.Embargo); + NewDtbdctTestbed().RunTest(Testsuite.EmbargoOnPromisedAnswer); + } + + [TestMethod] + public void EmbargoOnImportedCap() + { + NewDtbdctTestbed().RunTest(Testsuite.EmbargoOnImportedCap); } [TestMethod] @@ -63,6 +69,18 @@ namespace Capnp.Net.Runtime.Tests NewDtbdctTestbed().RunTest(Testsuite.PromiseResolve); } + [TestMethod] + public void PromiseResolveLate() + { + NewDtbdctTestbed().RunTest(Testsuite.PromiseResolveLate); + } + + [TestMethod] + public void PromiseResolveError() + { + NewDtbdctTestbed().RunTest(Testsuite.PromiseResolveError); + } + [TestMethod] public void Cancelation() { @@ -116,5 +134,29 @@ namespace Capnp.Net.Runtime.Tests { NewDtbdctTestbed().RunTest(Testsuite.Ownership3); } + + [TestMethod] + public void SillySkeleton() + { + NewDtbdctTestbed().RunTest(Testsuite.SillySkeleton); + } + + [TestMethod] + public void ImportReceiverAnswer() + { + NewDtbdctTestbed().RunTest(Testsuite.ImportReceiverAnswer); + } + + [TestMethod] + public void ImportReceiverAnswerError() + { + NewDtbdctTestbed().RunTest(Testsuite.ImportReceiverAnswerError); + } + + [TestMethod] + public void ImportReceiverAnswerCanceled() + { + NewDtbdctTestbed().RunTest(Testsuite.ImportReceiverCanceled); + } } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpcErrorHandling.cs b/Capnp.Net.Runtime.Tests/EdgeCaseHandling.cs similarity index 50% rename from Capnp.Net.Runtime.Tests/TcpRpcErrorHandling.cs rename to Capnp.Net.Runtime.Tests/EdgeCaseHandling.cs index 8bd35b2..b399f0a 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcErrorHandling.cs +++ b/Capnp.Net.Runtime.Tests/EdgeCaseHandling.cs @@ -16,7 +16,7 @@ namespace Capnp.Net.Runtime.Tests { [TestClass] [TestCategory("Coverage")] - public class TcpRpcErrorHandling: TestBase + public class EdgeCaseHandling: TestBase { class MemStreamEndpoint : IEndpoint { @@ -255,6 +255,74 @@ namespace Capnp.Net.Runtime.Tests tester.ExpectAbort(); } + [TestMethod] + public void UnimplementedReturnAcceptFromThirdParty() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.AcceptFromThirdParty; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void UnimplementedReturnUnknown() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = (Return.WHICH)33; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void InvalidReturnTakeFromOtherQuestion() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.TakeFromOtherQuestion; + _.Return.TakeFromOtherQuestion = 1u; + }); + tester.ExpectAbort(); + } + [TestMethod] public void InvalidReceiverHosted() { @@ -309,7 +377,92 @@ namespace Capnp.Net.Runtime.Tests } [TestMethod] - public void DuplicateResolve() + public void InvalidCallTargetImportedCap() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 0; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 1; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId + 1; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestInterface).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + _.Call.Params.Content.Rewrap(); + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void InvalidCallTargetPromisedAnswer() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 0; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 1; + _.Call.Target.which = MessageTarget.WHICH.PromisedAnswer; + _.Call.Target.PromisedAnswer.QuestionId = 1; + _.Call.Target.PromisedAnswer.Transform.Init(1); + _.Call.Target.PromisedAnswer.Transform[0].which = PromisedAnswer.Op.WHICH.GetPointerField; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestInterface).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + _.Call.Params.Content.Rewrap(); + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void UnimplementedCallTargetUnknown() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 0; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 1; + _.Call.Target.which = (MessageTarget.WHICH)77; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestInterface).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + _.Call.Params.Content.Rewrap(); + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void DuplicateResolve1() { var tester = new RpcEngineTester(); @@ -346,6 +499,125 @@ namespace Capnp.Net.Runtime.Tests tester.Recv(_ => { Assert.AreEqual(Message.WHICH.Release, _.which); }); + + // tester.ExpectAbort(); + + // Duplicate resolve is only a protocol error if the Rpc engine can prove misbehavior. + // In this case that proof is not possible because the preliminary cap is release (thus, removed from import table) + // immediately after the first resolution. Now we get the situation that the 2nd resolution refers to a non-existing + // cap. This is not considered a protocol error because it might be due to an expected race condition + // between receiver-side Release and sender-side Resolve. + } + + [TestMethod] + public void DuplicateResolve2() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderPromise; + _.Return.Results.CapTable[0].SenderPromise = 0; + _.Return.Results.Content.SetCapability(0); + }); + proxy.Call(0, 0, DynamicSerializerState.CreateForRpc()); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Call, _.which); + }); + tester.Send(_ => { + _.which = Message.WHICH.Resolve; + _.Resolve.which = Resolve.WHICH.Cap; + _.Resolve.Cap.which = CapDescriptor.WHICH.SenderHosted; + _.Resolve.Cap.SenderHosted = 1; + }); + tester.Send(_ => { + _.which = Message.WHICH.Resolve; + _.Resolve.which = Resolve.WHICH.Exception; + _.Resolve.Exception.Reason = "problem"; + }); + + tester.ExpectAbort(); + } + + [TestMethod] + public void UnimplementedResolveCategory() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderPromise; + _.Return.Results.CapTable[0].SenderPromise = 0; + }); + tester.Send(_ => { + _.which = Message.WHICH.Resolve; + _.Resolve.which = (Resolve.WHICH)7; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void InvalidResolve() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderHosted; + _.Return.Results.CapTable[0].SenderHosted = 7; + }); + tester.Send(_ => { + _.which = Message.WHICH.Resolve; + _.Resolve.which = Resolve.WHICH.Cap; + _.Resolve.PromiseId = 7; + _.Resolve.Cap.which = CapDescriptor.WHICH.SenderHosted; + _.Resolve.Cap.SenderHosted = 1; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + tester.ExpectAbort(); } @@ -506,6 +778,73 @@ namespace Capnp.Net.Runtime.Tests Assert.IsFalse(tester.IsDismissed); } + [TestMethod] + public void UnimplementedSendResultsToThirdParty() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { + _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 0; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 42; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId; + _.Call.InterfaceId = new TestInterface_Skeleton().InterfaceId; + _.Call.MethodId = 0; + var wr = _.Call.Params.Content.Rewrap(); + _.Call.Params.CapTable.Init(0); + _.Call.SendResultsTo.which = Call.sendResultsTo.WHICH.ThirdParty; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void UnimplementedSendResultsToUnknown() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { + _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 0; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 42; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId; + _.Call.InterfaceId = new TestInterface_Skeleton().InterfaceId; + _.Call.MethodId = 0; + var wr = _.Call.Params.Content.Rewrap(); + _.Call.Params.CapTable.Init(0); + _.Call.SendResultsTo.which = (Call.sendResultsTo.WHICH)13; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + Assert.IsFalse(tester.IsDismissed); + } + class TestPipelineImpl3 : ITestPipeline { readonly TestPipelineImpl2 _impl; @@ -605,5 +944,242 @@ namespace Capnp.Net.Runtime.Tests }); Assert.IsFalse(tester.IsDismissed); } + + [TestMethod] + public void SenderLoopbackOnInvalidCap() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Target.which = MessageTarget.WHICH.ImportedCap; + _.Disembargo.Target.ImportedCap = 0; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void SenderLoopbackOnInvalidPromisedAnswer() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Context.which = Disembargo.context.WHICH.SenderLoopback; + _.Disembargo.Context.SenderLoopback = 0; + _.Disembargo.Target.which = MessageTarget.WHICH.PromisedAnswer; + _.Disembargo.Target.PromisedAnswer.QuestionId = 9; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void SenderLoopbackOnUnknownTarget() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Context.which = Disembargo.context.WHICH.SenderLoopback; + _.Disembargo.Context.SenderLoopback = 0; + _.Disembargo.Target.which = (MessageTarget.WHICH)12; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void ReceiverLoopbackOnInvalidCap() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Context.which = Disembargo.context.WHICH.ReceiverLoopback; + _.Disembargo.Context.ReceiverLoopback = 0; + _.Disembargo.Target.which = MessageTarget.WHICH.ImportedCap; + _.Disembargo.Target.ImportedCap = 0; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void UnimplementedDisembargoAccept() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Context.which = Disembargo.context.WHICH.Accept; + _.Disembargo.Target.which = MessageTarget.WHICH.ImportedCap; + _.Disembargo.Target.ImportedCap = 0; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void UnimplementedDisembargoProvide() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Context.which = Disembargo.context.WHICH.Provide; + _.Disembargo.Target.which = MessageTarget.WHICH.ImportedCap; + _.Disembargo.Target.ImportedCap = 0; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void UnimplementedDisembargoUnknown() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + tester.Send(_ => { + _.which = Message.WHICH.Disembargo; + _.Disembargo.Context.which = (Disembargo.context.WHICH)50; + _.Disembargo.Target.which = MessageTarget.WHICH.ImportedCap; + _.Disembargo.Target.ImportedCap = 0; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + }); + } + + [TestMethod] + public void UnimplementedCall() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderHosted; + _.Return.Results.CapTable[0].SenderHosted = 1; + _.Return.Results.Content.SetCapability(0); + }); + Assert.IsTrue(proxy.WhenResolved.IsCompleted); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + var args = DynamicSerializerState.CreateForRpc(); + var ti = new TestInterfaceImpl(new Counters()); + args.ProvideCapability(ti); + proxy.Call(1, 2, args); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Call, _.which); + Assert.AreEqual(1ul, _.Call.InterfaceId); + Assert.AreEqual((ushort)2, _.Call.MethodId); + + Assert.IsFalse(ti.IsDisposed); + + tester.Send(_1 => + { + _1.which = Message.WHICH.Unimplemented; + _1.Unimplemented.which = Message.WHICH.Call; + Reserializing.DeepCopy(_.Call, _1.Unimplemented.Call); + }); + + Assert.IsTrue(ti.IsDisposed); + }); + } + + [TestMethod] + public void UnimplementedBootstrap() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + + tester.Send(_1 => + { + _1.which = Message.WHICH.Unimplemented; + _1.Unimplemented.which = Message.WHICH.Bootstrap; + Reserializing.DeepCopy(_.Bootstrap, _1.Unimplemented.Bootstrap); + }); + }); + + tester.ExpectAbort(); + } + + [TestMethod] + public void Abort() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + }); + + tester.Send(_ => { + _.which = Message.WHICH.Abort; + }); + } + + [TestMethod] + public void ThirdPartyHostedBootstrap() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + + tester.Send(_1 => + { + _1.which = Message.WHICH.Return; + _1.Return.AnswerId = _.Bootstrap.QuestionId; + _1.Return.which = Return.WHICH.Results; + _1.Return.Results.CapTable.Init(1); + _1.Return.Results.CapTable[0].which = CapDescriptor.WHICH.ThirdPartyHosted; + _1.Return.Results.CapTable[0].ThirdPartyHosted.VineId = 27; + _1.Return.Results.Content.SetCapability(0); + }); + }); + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + + proxy.Call(1, 2, DynamicSerializerState.CreateForRpc()); + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Call, _.which); + Assert.AreEqual(MessageTarget.WHICH.ImportedCap, _.Call.Target.which); + Assert.AreEqual(27u, _.Call.Target.ImportedCap); + }); + } } } diff --git a/Capnp.Net.Runtime.Tests/LocalRpc.cs b/Capnp.Net.Runtime.Tests/LocalRpc.cs index 2100293..7adc180 100644 --- a/Capnp.Net.Runtime.Tests/LocalRpc.cs +++ b/Capnp.Net.Runtime.Tests/LocalRpc.cs @@ -34,7 +34,7 @@ namespace Capnp.Net.Runtime.Tests [TestMethod] public void Embargo() { - NewLocalTestbed().RunTest(Testsuite.Embargo); + NewLocalTestbed().RunTest(Testsuite.EmbargoOnPromisedAnswer); } [TestMethod] @@ -132,5 +132,11 @@ namespace Capnp.Net.Runtime.Tests { NewLocalTestbed().RunTest(Testsuite.Ownership3); } + + [TestMethod] + public void ImportReceiverAnswer() + { + NewLocalTestbed().RunTest(Testsuite.Ownership3); + } } } diff --git a/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs b/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs index 6f65819..dbd91c5 100644 --- a/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs +++ b/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs @@ -827,6 +827,115 @@ namespace Capnp.Net.Runtime.Tests.GenImpls } } + class TestMoreStuffImpl2 : ITestMoreStuff + { + readonly TaskCompletionSource _echo = new TaskCompletionSource(); + readonly TaskCompletionSource _held = new TaskCompletionSource(); + ITestCallOrder _cap; + int _callCount; + + public TestMoreStuffImpl2() + { + } + + public async Task CallFoo(ITestInterface cap, CancellationToken cancellationToken_) + { + using (cap) + { + string s = await cap.Foo(123, true, cancellationToken_); + Assert.AreEqual("foo", s); + } + return "bar"; + } + + public Task CallFooWhenResolved(ITestInterface cap, CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + + public Task CallHeld(CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + + public void Dispose() + { + } + + public Task Echo(ITestCallOrder cap, CancellationToken cancellationToken_) + { + _cap = cap; + return Task.FromResult(_echo.Task.Eager(true)); + } + + public void EnableEcho() + { + _echo.SetResult(_cap); + } + + public Task ExpectCancel(ITestInterface cap, CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + + public Task GetCallSequence(uint expected, CancellationToken cancellationToken_) + { + return Task.FromResult((uint)(Interlocked.Increment(ref _callCount) - 1)); + } + + public Task GetEnormousString(CancellationToken cancellationToken_) + { + return Task.FromResult(new string(new char[100000000])); + } + + public Task GetHandle(CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + + public Task GetHeld(CancellationToken cancellationToken_) + { + return _held.Task; + } + + public Task GetNull(CancellationToken cancellationToken_) + { + return Task.FromResult(default(ITestMoreStuff)); + } + + public async Task Hold(ITestInterface cap, CancellationToken cancellationToken_) + { + try + { + var unwrapped = await cap.Unwrap(); + _held.SetResult(unwrapped); + } + catch (System.Exception exception) when (exception.Message == new TaskCanceledException().Message) + { + _held.SetCanceled(); + } + catch (System.Exception exception) + { + _held.SetException(exception); + } + } + + public Task<(string, string)> MethodWithDefaults(string a, uint b, string c, CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + + public Task MethodWithNullDefault(string a, ITestInterface b, CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + + public Task NeverReturn(ITestInterface cap, CancellationToken cancellationToken_) + { + throw new NotImplementedException(); + } + } + #endregion TestMoreStuff #region TestHandle diff --git a/Capnp.Net.Runtime.Tests/TcpRpc.cs b/Capnp.Net.Runtime.Tests/TcpRpc.cs index 56d8229..b122da4 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpc.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpc.cs @@ -44,7 +44,7 @@ namespace Capnp.Net.Runtime.Tests int MediumTimeout => Debugger.IsAttached ? Timeout.Infinite : 2000; - [TestMethod, Timeout(10000)] + [TestMethod] public void CreateAndDispose() { (var server, var client) = SetupClientServerPair(); @@ -55,7 +55,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void ConnectAndDispose() { (var server, var client) = SetupClientServerPair(); @@ -77,7 +77,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void ConnectNoServer() { using (var client = new TcpRpcClient("localhost", TcpPort)) @@ -86,7 +86,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void ConnectAndBootstrap() { (var server, var client) = SetupClientServerPair(); @@ -105,7 +105,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void ConnectNoBootstrap() { (var server, var client) = SetupClientServerPair(); @@ -123,7 +123,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallReturn() { (var server, var client) = SetupClientServerPair(); @@ -164,7 +164,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallCancelOnServer() { (var server, var client) = SetupClientServerPair(); @@ -199,7 +199,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallCancelOnClient() { ExpectingLogOutput = false; @@ -244,7 +244,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallReturnAfterClientSideCancel() { ExpectingLogOutput = false; @@ -306,7 +306,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallServerSideException() { (var server, var client) = SetupClientServerPair(); @@ -343,7 +343,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineBeforeReturn() { (var server, var client) = SetupClientServerPair(); @@ -416,7 +416,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineAfterReturn() { (var server, var client) = SetupClientServerPair(); @@ -492,7 +492,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineMultiple() { (var server, var client) = SetupClientServerPair(); @@ -609,7 +609,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineCallAfterDisposal() { (var server, var client) = SetupClientServerPair(); @@ -656,7 +656,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineCallDuringDisposal() { (var server, var client) = SetupClientServerPair(); diff --git a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs index e029d0a..ecd5d8d 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs @@ -13,7 +13,7 @@ namespace Capnp.Net.Runtime.Tests [TestCategory("Coverage")] public class TcpRpcAdvancedStuff : TestBase { - [TestMethod, Timeout(10000)] + [TestMethod] public void MultiConnect() { using (var server = SetupServer()) @@ -51,7 +51,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void TwoClients() { using (var server = SetupServer()) @@ -90,7 +90,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void ClosingServerWhileRequestingBootstrap() { for (int i = 0; i < 100; i++) diff --git a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs index 97228ca..e093a27 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs @@ -128,7 +128,7 @@ namespace Capnp.Net.Runtime.Tests IncrementTcpPort(); } - [TestMethod, Timeout(10000)] + [TestMethod] public void BasicClient() { LaunchCompatTestProcess("server:Interface", stdout => @@ -159,7 +159,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void BasicServer() { using (var server = SetupServer()) @@ -176,7 +176,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineClient() { LaunchCompatTestProcess("server:Pipeline", stdout => @@ -213,7 +213,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void PipelineServer() { using (var server = SetupServer()) @@ -232,7 +232,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void ReleaseClient() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -265,7 +265,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void ReleaseServer() { using (var server = SetupServer()) @@ -377,7 +377,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void ReleaseOnCancelServer() { using (var server = SetupServer()) @@ -464,7 +464,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CancelationServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -494,7 +494,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void CancelationClient() { using (var server = SetupServer()) @@ -511,7 +511,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void PromiseResolveServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -553,7 +553,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void PromiseResolveClient() { using (var server = SetupServer()) @@ -573,7 +573,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void RetainAndReleaseServer() { var destructionPromise = new TaskCompletionSource(); @@ -671,7 +671,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CancelServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -712,7 +712,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void CancelClient() { using (var server = SetupServer()) @@ -729,7 +729,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void SendTwiceServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -773,7 +773,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void SendTwiceClient() { using (var server = SetupServer()) @@ -793,7 +793,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -878,7 +878,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoServer2() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -961,7 +961,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoClient() { using (var server = SetupServer()) @@ -1031,7 +1031,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoErrorServer() { LaunchCompatTestProcess("server:MoreStuff", EmbargoErrorImpl); @@ -1049,7 +1049,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoErrorClient() { using (var server = SetupServer()) @@ -1065,7 +1065,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoNullServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -1125,7 +1125,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void EmbargoNullClient() { using (var server = SetupServer()) @@ -1141,7 +1141,7 @@ namespace Capnp.Net.Runtime.Tests } } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallBrokenPromiseServer() { LaunchCompatTestProcess("server:MoreStuff", stdout => @@ -1178,7 +1178,7 @@ namespace Capnp.Net.Runtime.Tests }); } - [TestMethod, Timeout(10000)] + [TestMethod] public void CallBrokenPromiseClient() { using (var server = SetupServer()) diff --git a/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs b/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs index e878920..9e7cd2a 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcPorted.cs @@ -119,7 +119,7 @@ namespace Capnp.Net.Runtime.Tests [TestMethod] public void Embargo() { - NewLocalhostTcpTestbed().RunTest(Testsuite.Embargo); + NewLocalhostTcpTestbed().RunTest(Testsuite.EmbargoOnPromisedAnswer); } [TestMethod] diff --git a/Capnp.Net.Runtime.Tests/Testsuite.cs b/Capnp.Net.Runtime.Tests/Testsuite.cs index 1b3486d..90b4675 100644 --- a/Capnp.Net.Runtime.Tests/Testsuite.cs +++ b/Capnp.Net.Runtime.Tests/Testsuite.cs @@ -53,7 +53,7 @@ namespace Capnp.Net.Runtime.Tests ftask.GetAwaiter().GetResult(); // re-throw exception } - public static void Embargo(ITestbed testbed) + public static void EmbargoOnPromisedAnswer(ITestbed testbed) { var counters = new Counters(); var impl = new TestMoreStuffImpl(counters); @@ -113,7 +113,64 @@ namespace Capnp.Net.Runtime.Tests } } } - + + public static void EmbargoOnImportedCap(ITestbed testbed) + { + var impl = new TestMoreStuffImpl2(); + + using (var main = testbed.ConnectMain(impl)) + { + var cap = new TestCallOrderImpl(); + cap.CountToDispose = 6; + + var earlyCall = main.GetCallSequence(0, default); + + var echo = main.Echo(cap, default); + testbed.MustComplete(echo); + using (var pipeline = echo.Result) + { + var call0 = pipeline.GetCallSequence(0, default); + var call1 = pipeline.GetCallSequence(1, default); + + testbed.MustComplete(earlyCall); + + impl.EnableEcho(); + + var call2 = pipeline.GetCallSequence(2, default); + + testbed.MustComplete(echo); + using (var resolved = echo.Result) + { + var call3 = pipeline.GetCallSequence(3, default); + var call4 = pipeline.GetCallSequence(4, default); + var call5 = pipeline.GetCallSequence(5, default); + + try + { + testbed.MustComplete(call0); + testbed.MustComplete(call1); + testbed.MustComplete(call2); + testbed.MustComplete(call3); + testbed.MustComplete(call4); + testbed.MustComplete(call5); + } + catch (System.Exception) + { + cap.CountToDispose = null; + throw; + } + + Assert.AreEqual(0u, call0.Result); + Assert.AreEqual(1u, call1.Result); + Assert.AreEqual(2u, call2.Result); + Assert.AreEqual(3u, call3.Result); + Assert.AreEqual(4u, call4.Result); + Assert.AreEqual(5u, call5.Result); + } + } + } + } + public static void EmbargoError(ITestbed testbed) { var counters = new Counters(); @@ -404,6 +461,55 @@ namespace Capnp.Net.Runtime.Tests } } + public static void PromiseResolveLate(ITestbed testbed) + { + var counters = new Counters(); + var impl = new TestMoreStuffImpl(counters); + using (var main = testbed.ConnectMain(impl)) + { + var tcs = new TaskCompletionSource(); + var disposed = new TaskCompletionSource(); + using (var eager = tcs.Task.Eager(true)) + { + var request = main.NeverReturn(Proxy.Share(eager), new CancellationToken(true)); + + testbed.MustComplete(request); + + var tiimpl = new TestInterfaceImpl(new Counters(), disposed); + tcs.SetResult(tiimpl); + + Assert.IsFalse(tiimpl.IsDisposed); + } + testbed.MustComplete(disposed.Task); + } + } + + public static void PromiseResolveError(ITestbed testbed) + { + var counters = new Counters(); + var impl = new TestMoreStuffImpl(counters); + using (var main = testbed.ConnectMain(impl)) + { + var tcs = new TaskCompletionSource(); + using (var eager = tcs.Task.Eager(true)) + { + var request = main.CallFoo(Proxy.Share(eager), default); + var request2 = main.CallFooWhenResolved(eager, default); + + var gcs = main.GetCallSequence(0, default); + testbed.MustComplete(gcs); + Assert.AreEqual(2u, gcs.Result); + Assert.AreEqual(3, counters.CallCount); + + tcs.SetException(new System.Exception("too bad")); + + testbed.MustComplete(request, request2); + Assert.IsTrue(request.IsFaulted); + Assert.IsTrue(request2.IsFaulted); + } + } + } + public static void Cancelation(ITestbed testbed) { var counters = new Counters(); @@ -570,5 +676,78 @@ namespace Capnp.Net.Runtime.Tests testbed.MustComplete(tcs.Task); } } + + class ThrowingSkeleton : Skeleton + { + public bool WasCalled { get; private set; } + + public override Task Invoke(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken cancellationToken = default) + { + WasCalled = true; + throw new NotImplementedException(); + } + } + + public static void SillySkeleton(ITestbed testbed) + { + var impl = new ThrowingSkeleton(); + using (var main = testbed.ConnectMain(impl)) + { + var tcs = new TaskCompletionSource(); + var ti = new TestInterfaceImpl(new Counters(), tcs); + testbed.ExpectPromiseThrows(main.CallFoo(ti)); + Assert.IsTrue(impl.WasCalled); + testbed.MustComplete(tcs.Task); + } + } + + public static void ImportReceiverAnswer(ITestbed testbed) + { + var impl = new TestMoreStuffImpl2(); + using (var main = testbed.ConnectMain(impl)) + { + var held = main.GetHeld().Eager(); + var foo = main.CallFoo(held); + testbed.MustNotComplete(foo); + var tcs = new TaskCompletionSource(); + testbed.MustComplete( + main.Hold(new TestInterfaceImpl(new Counters(), tcs)), + foo, + tcs.Task); + } + } + + public static void ImportReceiverAnswerError(ITestbed testbed) + { + var impl = new TestMoreStuffImpl2(); + using (var main = testbed.ConnectMain(impl)) + using (var held = main.GetHeld().Eager()) + { + var foo = main.CallFoo(held); + testbed.MustNotComplete(foo); + var faulted = Task.FromException( + new InvalidOperationException("I faulted")).Eager(true); + testbed.MustComplete( + main.Hold(faulted), + foo); + Assert.IsTrue(foo.IsFaulted); + } + } + + public static void ImportReceiverCanceled(ITestbed testbed) + { + var impl = new TestMoreStuffImpl2(); + using (var main = testbed.ConnectMain(impl)) + using (var held = main.GetHeld().Eager()) + { + var foo = main.CallFoo(held); + testbed.MustNotComplete(foo); + var canceled = Task.FromCanceled(new CancellationToken(true)).Eager(true); + testbed.MustComplete( + main.Hold(canceled), + foo); + Assert.IsTrue(foo.IsCanceled); + } + } } } diff --git a/Capnp.Net.Runtime.Tests/Util/TestBase.cs b/Capnp.Net.Runtime.Tests/Util/TestBase.cs index 2edffbc..6352051 100644 --- a/Capnp.Net.Runtime.Tests/Util/TestBase.cs +++ b/Capnp.Net.Runtime.Tests/Util/TestBase.cs @@ -15,7 +15,7 @@ namespace Capnp.Net.Runtime.Tests { public interface ITestbed { - T ConnectMain(T main) where T : class; + T ConnectMain(object main) where T : class; void MustComplete(params Task[] tasks); void MustNotComplete(params Task[] tasks); void FlushCommunication(); @@ -52,8 +52,11 @@ namespace Capnp.Net.Runtime.Tests public void Dismiss() { - OtherEndpoint.Dismiss(); - _dismissed = true; + if (!_dismissed) + { + _dismissed = true; + OtherEndpoint.Dismiss(); + } } public void Forward(WireFrame frame) @@ -156,9 +159,9 @@ namespace Capnp.Net.Runtime.Tests action(this); } - T ITestbed.ConnectMain(T main) + T ITestbed.ConnectMain(object main) { - return main; + return (T)main; } void ITestbed.FlushCommunication() @@ -199,7 +202,7 @@ namespace Capnp.Net.Runtime.Tests }); } - T ITestbed.ConnectMain(T main) + T ITestbed.ConnectMain(object main) { return SetupEnginePair(main, _decisionTree, out _enginePair); } @@ -242,7 +245,7 @@ namespace Capnp.Net.Runtime.Tests } } - T ITestbed.ConnectMain(T main) + T ITestbed.ConnectMain(object main) { _server.Main = main; return _client.GetMain(); @@ -330,7 +333,7 @@ namespace Capnp.Net.Runtime.Tests return (server, client); } - protected static T SetupEnginePair(T main, DecisionTree decisionTree, out EnginePair pair) where T: class + protected static T SetupEnginePair(object main, DecisionTree decisionTree, out EnginePair pair) where T: class { pair = new EnginePair(decisionTree); pair.Engine1.Main = main; diff --git a/Capnp.Net.Runtime/Rpc/Impatient.cs b/Capnp.Net.Runtime/Rpc/Impatient.cs index 7b2c7b3..4b73082 100644 --- a/Capnp.Net.Runtime/Rpc/Impatient.cs +++ b/Capnp.Net.Runtime/Rpc/Impatient.cs @@ -158,8 +158,16 @@ namespace Capnp.Rpc throw new ArgumentException("The task was not returned from a remote method invocation. See documentation for details."); } - var lazyCap = new LazyCapability(task.AsProxyTask()); - return (CapabilityReflection.CreateProxy(lazyCap) as TInterface)!; + var proxyTask = task.AsProxyTask(); + if (proxyTask.ReplacementTaskIsCompletedSuccessfully()) + { + return proxyTask.Result.Cast(true); + } + else + { + var lazyCap = new LazyCapability(proxyTask); + return (CapabilityReflection.CreateProxy(lazyCap) as TInterface)!; + } } else { @@ -172,6 +180,17 @@ namespace Capnp.Rpc } } + public static async Task Unwrap(this TInterface cap) where TInterface: class, IDisposable + { + using var proxy = cap as Proxy; + + if (proxy == null) + return cap; + + var unwrapped = await proxy.ConsumedCap.Unwrap(); + return ((CapabilityReflection.CreateProxy(unwrapped)) as TInterface)!; + } + internal static IRpcEndpoint? AskingEndpoint { get => _askingEndpoint.Value; diff --git a/Capnp.Net.Runtime/Rpc/LazyCapability.cs b/Capnp.Net.Runtime/Rpc/LazyCapability.cs index ae1eef4..6b11347 100644 --- a/Capnp.Net.Runtime/Rpc/LazyCapability.cs +++ b/Capnp.Net.Runtime/Rpc/LazyCapability.cs @@ -64,8 +64,7 @@ namespace Capnp.Rpc if (WhenResolved.ReplacementTaskIsCompletedSuccessfully()) { using var proxy = new Proxy(WhenResolved.Result); - proxy.Export(endpoint, writer); - return null; + return proxy.Export(endpoint, writer); } else { diff --git a/Capnp.Net.Runtime/Rpc/PendingQuestion.cs b/Capnp.Net.Runtime/Rpc/PendingQuestion.cs index b660d6d..394f4f2 100644 --- a/Capnp.Net.Runtime/Rpc/PendingQuestion.cs +++ b/Capnp.Net.Runtime/Rpc/PendingQuestion.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Diagnostics; using System.Threading.Tasks; @@ -89,6 +90,7 @@ namespace Capnp.Rpc internal object ReentrancyBlocker { get; } = new object(); internal uint QuestionId => _questionId; internal State StateFlags { get; private set; } + internal IReadOnlyList? CapTable { get; set; } /// /// Eventually returns the server answer @@ -323,6 +325,7 @@ namespace Capnp.Rpc try { RpcEndpoint.SendQuestion(inParams, call.Params); + CapTable = call.Params.CapTable; } catch (System.Exception exception) { diff --git a/Capnp.Net.Runtime/Rpc/Proxy.cs b/Capnp.Net.Runtime/Rpc/Proxy.cs index 22957a9..65616b0 100644 --- a/Capnp.Net.Runtime/Rpc/Proxy.cs +++ b/Capnp.Net.Runtime/Rpc/Proxy.cs @@ -131,18 +131,17 @@ namespace Capnp.Rpc #endif } - internal Skeleton? GetProvider() + internal async Task GetProvider() { - switch (ConsumedCap) + var unwrapped = await ConsumedCap.Unwrap(); + + switch (unwrapped) { case LocalCapability lcap: return lcap.ProvidedCap; - case null: - return null; - default: - return Vine.Create(ConsumedCap); + return Vine.Create(unwrapped); } } @@ -214,15 +213,20 @@ namespace Capnp.Rpc } } - internal void Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer) + internal Action? Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer) { if (_disposedValue) throw new ObjectDisposedException(nameof(Proxy)); if (ConsumedCap == null) + { writer.which = CapDescriptor.WHICH.None; + return null; + } else - ConsumedCap.Export(endpoint, writer); + { + return ConsumedCap.Export(endpoint, writer); + } } internal void Freeze(out IRpcEndpoint? boundEndpoint) diff --git a/Capnp.Net.Runtime/Rpc/ResolvingCapabilityExtensions.cs b/Capnp.Net.Runtime/Rpc/ResolvingCapabilityExtensions.cs index 0461d75..86c5fa8 100644 --- a/Capnp.Net.Runtime/Rpc/ResolvingCapabilityExtensions.cs +++ b/Capnp.Net.Runtime/Rpc/ResolvingCapabilityExtensions.cs @@ -5,6 +5,18 @@ namespace Capnp.Rpc { static class ResolvingCapabilityExtensions { + public static async Task Unwrap(this ConsumedCapability? cap) + { + cap ??= LazyCapability.Null; + + while (cap is IResolvingCapability resolving) + { + cap = await resolving.WhenResolved ?? LazyCapability.Null; + } + + return cap; + } + public static Action? ExportAsSenderPromise(this T cap, IRpcEndpoint endpoint, CapDescriptor.WRITER writer) where T: ConsumedCapability, IResolvingCapability { @@ -20,7 +32,7 @@ namespace Capnp.Rpc try { - var resolvedCap = await cap.WhenResolved; + var resolvedCap = await Unwrap(await cap.WhenResolved); endpoint.Resolve(preliminaryId, vine, () => resolvedCap!); } catch (System.Exception exception) diff --git a/Capnp.Net.Runtime/Rpc/RpcEngine.cs b/Capnp.Net.Runtime/Rpc/RpcEngine.cs index 11c99b5..c59c3d8 100644 --- a/Capnp.Net.Runtime/Rpc/RpcEngine.cs +++ b/Capnp.Net.Runtime/Rpc/RpcEngine.cs @@ -30,16 +30,9 @@ namespace Capnp.Rpc ++RefCount; } - public void Release() - { - --RefCount; - CheckDispose(); - } - public void ReleaseAll() { RefCount = 0; - CheckDispose(); } public void Release(int count) @@ -48,15 +41,6 @@ namespace Capnp.Rpc throw new ArgumentOutOfRangeException(nameof(count)); RefCount -= count; - CheckDispose(); - } - - void CheckDispose() - { - if (RefCount == 0 && Cap is IDisposable disposable) - { - disposable.Dispose(); - } } } @@ -90,7 +74,7 @@ namespace Capnp.Rpc readonly RpcEngine _host; readonly IEndpoint _tx; - readonly Dictionary>> _importTable = new Dictionary>>(); + readonly Dictionary> _importTable = new Dictionary>(); readonly Dictionary> _exportTable = new Dictionary>(); readonly Dictionary _revExportTable = new Dictionary(); readonly Dictionary _questionTable = new Dictionary(); @@ -207,11 +191,17 @@ namespace Capnp.Rpc void SendAbort(string reason) { - var mb = MessageBuilder.Create(); - var msg = mb.BuildRoot(); - msg.which = Message.WHICH.Abort; - msg.Abort!.Reason = reason; - Tx(mb.Frame); + try + { + var mb = MessageBuilder.Create(); + var msg = mb.BuildRoot(); + msg.which = Message.WHICH.Abort; + msg.Abort!.Reason = reason; + Tx(mb.Frame); + } + catch // Take care that an exception does not prevent shutdown. + { + } } void IRpcEndpoint.Resolve(uint preliminaryId, Skeleton preliminaryCap, Func resolvedCapGetter) @@ -305,40 +295,16 @@ namespace Capnp.Rpc lock (_reentrancyBlocker) { uint questionId = NextId(); - var question = new PendingQuestion(this, questionId, target, inParams); - - while (!_questionTable.ReplacementTryAdd(questionId, question)) - { + while (_questionTable.ContainsKey(questionId)) questionId = NextId(); - var oldQuestion = question; - question = new PendingQuestion(this, questionId, target, inParams); - oldQuestion.Dispose(); - } + + var question = new PendingQuestion(this, questionId, target, inParams); + _questionTable.Add(questionId, question); return question; } } - void DeleteQuestion(uint id, PendingQuestion question) - { - lock (_reentrancyBlocker) - { - if (!_questionTable.TryGetValue(id, out var existingQuestion)) - { - Logger.LogError("Attempting to delete unknown question ID. Race condition?"); - return; - } - - if (question != existingQuestion) - { - Logger.LogError("Found different question under given ID. WTF???"); - return; - } - - _questionTable.Remove(id); - } - } - (TaskCompletionSource, uint) AllocateDisembargo() { var tcs = new TaskCompletionSource(); @@ -444,9 +410,9 @@ namespace Capnp.Rpc } } - Skeleton? callTargetCap; + Skeleton callTargetCap; PendingAnswer pendingAnswer; - bool releaseParamCaps = false; + bool releaseParamCaps = true; void AwaitAnswerAndReply() { @@ -564,32 +530,27 @@ namespace Capnp.Rpc { var inParams = req.Params.Content; inParams.Caps = ImportCapTable(req.Params); + releaseParamCaps = false; - if (callTargetCap == null) + try { - releaseParamCaps = true; - pendingAnswer = new PendingAnswer( - Task.FromException( - new RpcException("Call target resolved to null")), null); + var cts = new CancellationTokenSource(); + var callTask = callTargetCap.Invoke(req.InterfaceId, req.MethodId, inParams, cts.Token); + pendingAnswer = new PendingAnswer(callTask, cts); } - else + catch (System.Exception exception) { - try + foreach (var cap in inParams.Caps) { - var cts = new CancellationTokenSource(); - var callTask = callTargetCap.Invoke(req.InterfaceId, req.MethodId, inParams, cts.Token); - pendingAnswer = new PendingAnswer(callTask, cts); - } - catch (System.Exception exception) - { - releaseParamCaps = true; - pendingAnswer = new PendingAnswer( - Task.FromException(exception), null); - } - finally - { - callTargetCap.Relinquish(); + cap?.Release(); } + + pendingAnswer = new PendingAnswer( + Task.FromException(exception), null); + } + finally + { + callTargetCap.Relinquish(); } AwaitAnswerAndReply(); @@ -657,8 +618,8 @@ namespace Capnp.Rpc try { using var proxy = await t; - callTargetCap = proxy?.GetProvider(); - callTargetCap?.Claim(); + callTargetCap = await proxy.GetProvider(); + callTargetCap.Claim(); CreateAnswerAwaitItAndReply(); } catch (TaskCanceledException) @@ -715,6 +676,11 @@ namespace Capnp.Rpc } } + if (req.ReleaseParamCaps) + { + ReleaseExports(question.CapTable); + } + switch (req.which) { case Return.WHICH.Results: @@ -782,6 +748,7 @@ namespace Capnp.Rpc else { Logger.LogWarning("Incoming RPC return: Peer requested to take results from other question, but specified ID is unknown (already released?)"); + throw new RpcProtocolErrorException("Invalid ID"); } } break; @@ -793,48 +760,53 @@ namespace Capnp.Rpc void ProcessResolve(Resolve.READER resolve) { - if (!_importTable.TryGetValue(resolve.PromiseId, out var rcw)) + lock (_reentrancyBlocker) { - Logger.LogWarning("Received a resolve message with invalid ID"); - throw new RpcProtocolErrorException($"Invalid ID {resolve.PromiseId}"); - } - - if (!rcw.Cap.TryGetTarget(out var cap)) - { - // Silently discard - return; - } - - if (!(cap is PromisedCapability resolvableCap)) - { - Logger.LogWarning("Received a resolve message for a capability which is not a promise"); - throw new RpcProtocolErrorException($"Not a promise {resolve.PromiseId}"); - } - - try - { - switch (resolve.which) + if (!_importTable.TryGetValue(resolve.PromiseId, out var rcc)) { - case Resolve.WHICH.Cap: - lock (_reentrancyBlocker) - { + // May happen if Resolve arrives late. Not an actual error. + + if (resolve.which == Resolve.WHICH.Cap) + { + // Import and release immediately + var imp = ImportCap(resolve.Cap); + imp.AddRef(); + imp.Release(); + } + + return; + } + + var cap = rcc.Cap; + + if (!(cap is PromisedCapability resolvableCap)) + { + Logger.LogWarning("Received a resolve message for a capability which is not a promise"); + throw new RpcProtocolErrorException($"Not a promise {resolve.PromiseId}"); + } + + try + { + switch (resolve.which) + { + case Resolve.WHICH.Cap: var resolvedCap = ImportCap(resolve.Cap); resolvableCap.ResolveTo(resolvedCap); - } - break; + break; - case Resolve.WHICH.Exception: - resolvableCap.Break(resolve.Exception.Reason ?? "unknown reason"); - break; + case Resolve.WHICH.Exception: + resolvableCap.Break(resolve.Exception.Reason ?? "unknown reason"); + break; - default: - Logger.LogWarning("Received a resolve message with unknown category."); - throw new RpcUnimplementedException(); + default: + Logger.LogWarning("Received a resolve message with unknown category."); + throw new RpcUnimplementedException(); + } + } + catch (InvalidOperationException) + { + throw new RpcProtocolErrorException($"Capability {resolve.PromiseId} was already resolved"); } - } - catch (InvalidOperationException) - { - throw new RpcProtocolErrorException($"Capability {resolve.PromiseId} was already resolved"); } } @@ -958,14 +930,12 @@ namespace Capnp.Rpc Logger.LogDebug($"Receiver loopback disembargo, Thread = {Thread.CurrentThread.Name}"); #endif - if (!tcs.TrySetResult(0)) - { - Logger.LogError("Attempting to grant disembargo failed. Looks like an internal error / race condition."); - } + tcs.SetResult(0); } else { Logger.LogWarning("Peer sent receiver loopback with unknown ID"); + throw new RpcProtocolErrorException("Invalid ID"); } } @@ -988,6 +958,26 @@ namespace Capnp.Rpc } } + void ReleaseExports(IReadOnlyList? caps) + { + if (caps != null) + { + foreach (var capDesc in caps) + { + switch (capDesc.which) + { + case CapDescriptor.WHICH.SenderHosted: + ReleaseExport(capDesc.SenderHosted, 1); + break; + + case CapDescriptor.WHICH.SenderPromise: + ReleaseExport(capDesc.SenderPromise, 1); + break; + } + } + } + } + void ReleaseResultCaps(PendingAnswer answer) { answer.Chain(async t => @@ -995,24 +985,7 @@ namespace Capnp.Rpc try { var aorcq = await t; - var caps = answer.CapTable; - - if (caps != null) - { - foreach (var capDesc in caps) - { - switch (capDesc.which) - { - case CapDescriptor.WHICH.SenderHosted: - ReleaseExport(capDesc.SenderHosted, 1); - break; - - case CapDescriptor.WHICH.SenderPromise: - ReleaseExport(capDesc.SenderPromise, 1); - break; - } - } - } + ReleaseExports(answer.CapTable); } catch { @@ -1103,6 +1076,7 @@ namespace Capnp.Rpc break; case CapDescriptor.WHICH.SenderPromise: + // Not really expected that a promise gets resolved to another promise. ReleaseExport(resolve.Cap.SenderPromise, 1); break; @@ -1114,7 +1088,19 @@ namespace Capnp.Rpc void ProcessUnimplementedCall(Call.READER call) { - Finish(call.QuestionId); + PendingQuestion? question; + + lock (_reentrancyBlocker) + { + if (!_questionTable.TryGetValue(call.QuestionId, out question)) + { + Logger.LogWarning("Unimplemented call: Unknown question ID."); + + throw new RpcProtocolErrorException("Unknown question ID"); + } + } + + ReleaseExports(question.CapTable); } void ProcessUnimplemented(Message.READER unimplemented) @@ -1265,54 +1251,34 @@ namespace Capnp.Rpc switch (capDesc.which) { case CapDescriptor.WHICH.SenderHosted: - if (_importTable.TryGetValue(capDesc.SenderHosted, out var rcw)) + if (_importTable.TryGetValue(capDesc.SenderHosted, out var rcc)) { - if (rcw.Cap.TryGetTarget(out var impCap)) - { - impCap.Validate(); - rcw.AddRef(); - return impCap; - } - else - { - impCap = new ImportedCapability(this, capDesc.SenderHosted); - rcw.Cap.SetTarget(impCap); - } - - return impCap!; + var impCap = rcc.Cap; + impCap.Validate(); + rcc.AddRef(); + return impCap; } else { var newCap = new ImportedCapability(this, capDesc.SenderHosted); - rcw = new RefCounted>( - new WeakReference(newCap)); - _importTable.Add(capDesc.SenderHosted, rcw); + rcc = new RefCounted(newCap); + _importTable.Add(capDesc.SenderHosted, rcc); return newCap; } case CapDescriptor.WHICH.SenderPromise: - if (_importTable.TryGetValue(capDesc.SenderPromise, out var rcwp)) + if (_importTable.TryGetValue(capDesc.SenderPromise, out var rccp)) { - if (rcwp.Cap.TryGetTarget(out var impCap)) - { - impCap.Validate(); - rcwp.AddRef(); - return impCap; - } - else - { - impCap = new PromisedCapability(this, capDesc.SenderPromise); - rcwp.Cap.SetTarget(impCap); - } - + var impCap = rccp.Cap; + impCap.Validate(); + rccp.AddRef(); return impCap; } else { var newCap = new PromisedCapability(this, capDesc.SenderPromise); - rcw = new RefCounted>( - new WeakReference(newCap)); - _importTable.Add(capDesc.SenderPromise, rcw); + rccp = new RefCounted(newCap); + _importTable.Add(capDesc.SenderPromise, rccp); return newCap; } @@ -1362,25 +1328,15 @@ namespace Capnp.Rpc case CapDescriptor.WHICH.ThirdPartyHosted: if (_importTable.TryGetValue(capDesc.ThirdPartyHosted.VineId, out var rcv)) { - if (rcv.Cap.TryGetTarget(out var impCap)) - { - rcv.AddRef(); - impCap.Validate(); - return impCap; - } - else - { - impCap = new ImportedCapability(this, capDesc.ThirdPartyHosted.VineId); - rcv.Cap.SetTarget(impCap); - } - + var impCap = rcv.Cap; + rcv.AddRef(); + impCap.Validate(); return impCap; } else { var newCap = new ImportedCapability(this, capDesc.ThirdPartyHosted.VineId); - rcv = new RefCounted>( - new WeakReference(newCap)); + rcv = new RefCounted(newCap); return newCap; } @@ -1552,7 +1508,13 @@ namespace Capnp.Rpc void IRpcEndpoint.DeleteQuestion(PendingQuestion question) { - DeleteQuestion(question.QuestionId, question); + lock (_reentrancyBlocker) + { + if (!_questionTable.Remove(question.QuestionId)) + { + Logger.LogError("Attempting to delete unknown question ID."); + } + } } } diff --git a/Capnp.Net.Runtime/SerializerState.cs b/Capnp.Net.Runtime/SerializerState.cs index 8a19bba..42c5522 100644 --- a/Capnp.Net.Runtime/SerializerState.cs +++ b/Capnp.Net.Runtime/SerializerState.cs @@ -511,7 +511,7 @@ namespace Capnp } } - internal void SetCapability(uint capabilityIndex) + public void SetCapability(uint capabilityIndex) { if (Kind == ObjectKind.Nil) { @@ -1259,10 +1259,10 @@ namespace Capnp /// Adds an entry to the capability table if the provided capability does not yet exist. /// /// The capability, in one of the following forms: - /// Low-level capability object (Rpc.ConsumedCapability) - /// Proxy object (Rpc.Proxy). Note that the provision has "move semantics": SerializerState + /// Low-level capability () + /// Proxy object (). Note that the provision has "move semantics": SerializerState /// takes ownership, so the Proxy object will be disposed. - /// Skeleton object (Rpc.Skeleton) + /// instance /// Capability interface implementation /// /// Index of the given capability in the capability table