diff --git a/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj b/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj index eb11e5e..bc841b7 100644 --- a/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj +++ b/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj @@ -9,11 +9,11 @@ - + @@ -30,6 +30,7 @@ + diff --git a/Capnp.Net.Runtime.Tests/Interception.cs b/Capnp.Net.Runtime.Tests/Interception.cs new file mode 100644 index 0000000..ea7c1f5 --- /dev/null +++ b/Capnp.Net.Runtime.Tests/Interception.cs @@ -0,0 +1,624 @@ +using Capnp.Net.Runtime.Tests.GenImpls; +using Capnp.Rpc; +using Capnp.Rpc.Interception; +using Capnproto_test.Capnp.Test; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Dataflow; + +namespace Capnp.Net.Runtime.Tests +{ + [TestClass] + public class Interception: TestBase + { + class MyPolicy : IInterceptionPolicy + { + readonly string _id; + readonly BufferBlock _callSubject = new BufferBlock(); + readonly BufferBlock _returnSubject = new BufferBlock(); + + public MyPolicy(string id) + { + _id = id; + } + + public bool Equals(IInterceptionPolicy other) + { + return other is MyPolicy myPolicy && _id.Equals(myPolicy._id); + } + + public override bool Equals(object obj) + { + return obj is IInterceptionPolicy other && Equals(other); + } + + public override int GetHashCode() + { + return _id.GetHashCode(); + } + + public void OnCallFromAlice(CallContext callContext) + { + Assert.IsTrue(_callSubject.Post(callContext)); + } + + public void OnReturnFromBob(CallContext callContext) + { + Assert.IsTrue(_returnSubject.Post(callContext)); + } + + public IReceivableSourceBlock Calls => _callSubject; + public IReceivableSourceBlock Returns => _returnSubject; + } + + [TestMethod] + public void InterceptServerSideObserveCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = policy.Attach(new TestInterfaceImpl(counters)); + using (var main = client.GetMain()) + { + var request1 = main.Foo(123, true, default); + var fcc = policy.Calls.ReceiveAsync(); + Assert.IsTrue(fcc.Wait(MediumNonDbgTimeout)); + var cc = fcc.Result; + + var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs); + Assert.AreEqual(123u, pr.I); + + cc.ForwardToBob(); + + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs); + Assert.AreEqual("foo", rr.X); + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.Wait(MediumNonDbgTimeout)); + + Assert.AreEqual("foo", request1.Result); + } + } + } + + [TestMethod] + public void InterceptClientSideModifyCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + + Assert.AreEqual(InterceptionState.RequestedFromAlice, cc.State); + + var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs); + Assert.AreEqual(321u, pr.I); + Assert.AreEqual(false, pr.J); + + var pw = cc.InArgs.Rewrap(); + pw.I = 123u; + pw.J = true; + + cc.ForwardToBob(); + + var rx = policy.Returns.ReceiveAsync(); + + // Racing against Bob's answer + Assert.IsTrue(cc.State == InterceptionState.ForwardedToBob || rx.IsCompleted); + + Assert.IsTrue(rx.Wait(MediumNonDbgTimeout)); + var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs); + Assert.AreEqual("foo", rr.X); + + Assert.IsFalse(request1.IsCompleted); + + var rw = ((DynamicSerializerState)cc.OutArgs).Rewrap(); + rw.X = "bar"; + cc.OutArgs = rw; + + Assert.AreEqual(InterceptionState.ReturnedFromBob, cc.State); + cc.ReturnToAlice(); + Assert.AreEqual(InterceptionState.ReturnedToAlice, cc.State); + + Assert.IsTrue(request1.IsCompleted); + + Assert.AreEqual("bar", request1.Result); + } + } + + } + + [TestMethod] + public void InterceptClientSideShortCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + var rw = SerializerState.CreateForRpc(); + rw.X = "bar"; + cc.OutArgs = rw; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + + Assert.AreEqual("bar", request1.Result); + } + } + } + + [TestMethod] + public void InterceptClientSideRejectCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + cc.Exception = "rejected"; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.IsTrue(request1.IsFaulted); + Assert.AreEqual("rejected", request1.Exception.InnerException.Message); + } + } + } + + [TestMethod] + public void InterceptClientSideCancelReturn() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + Assert.IsFalse(cc.CancelFromAlice.IsCancellationRequested); + + cc.ReturnCanceled = true; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.IsTrue(request1.IsCanceled); + } + } + } + + [TestMethod] + public void InterceptClientSideOverrideCanceledCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, new CancellationToken(true)); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + Assert.IsTrue(cc.CancelFromAlice.IsCancellationRequested); + + cc.ForwardToBob(); + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + Assert.IsTrue(cc.ReturnCanceled); + cc.ReturnCanceled = false; + cc.Exception = "Cancelled"; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.IsTrue(request1.IsFaulted); + } + } + } + + [TestMethod] + public void InterceptClientSideRedirectCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(123, true, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + var counters2 = new Counters(); + var impl2 = new TestInterfaceImpl(counters2); + cc.Bob = impl2; + cc.ForwardToBob(); + + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.AreEqual("foo", request1.Result); + Assert.AreEqual(0, counters.CallCount); + Assert.AreEqual(1, counters2.CallCount); + } + } + } + + [TestMethod] + public void InterfaceAndMethodId() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var baz = main.Baz(new TestAllTypes()); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + + Assert.IsTrue(cc.MethodId == 2); + Assert.AreEqual(new TestInterface_Skeleton().InterfaceId, cc.InterfaceId); + } + } + } + + [TestMethod] + public void TailCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestTailCallerImpl(counters); + using (var main = client.GetMain()) + { + var calleeCallCount = new Counters(); + var callee = policy.Attach(new TestTailCalleeImpl(calleeCallCount)); + + var promise = main.Foo(456, callee, default); + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + var cc = ccf.Result; + cc.ForwardToBob(); + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); + Assert.AreEqual("from TestTailCaller", promise.Result.T); + } + } + } + + [TestMethod] + public void InterceptInCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestMoreStuffImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var counters2 = new Counters(); + var cap = new TestInterfaceImpl(counters2); + var promise = main.CallFoo(cap); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.InterceptInCaps(); + cc.ForwardToBob(); + var cc2f = policy.Calls.ReceiveAsync(); + Assert.IsTrue(cc2f.Wait(MediumNonDbgTimeout)); + var cc2 = cc2f.Result; + cc2.ForwardToBob(); + var cc2fr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(cc2fr.Wait(MediumNonDbgTimeout)); + Assert.AreSame(cc2, cc2fr.Result); + Assert.AreEqual(1, counters2.CallCount); + cc2.ReturnToAlice(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + Assert.AreSame(cc, ccfr.Result); + cc.ReturnToAlice(); + Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void InterceptOutCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = policy.Attach(new TestMoreStuffImpl(counters)); + using (var main = client.GetMain()) + { + var counters2 = new Counters(); + var cap = new TestInterfaceImpl(counters2); + main.Hold(cap); + { + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.ReturnToAlice(); + } + + var ghf = main.GetHeld(); + { + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.InterceptOutCaps(); + ccf.Result.ReturnToAlice(); + } + + Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout)); + var held = ghf.Result; + + var foof = held.Foo(123, true); + { + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.ReturnToAlice(); + } + + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void UninterceptOutCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestMoreStuffImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var counters2 = new Counters(); + var cap = new TestInterfaceImpl(counters2); + main.Hold(cap); + { + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.InterceptInCaps(); + cc.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + } + + main.CallHeld(); + { + // CallHeld + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.ForwardToBob(); + + // actual call on held cap. + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.ReturnToAlice(); + + ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + } + + var ghf = main.GetHeld(); + { + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.InterceptInCaps(); + cc.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + cc.UninterceptOutCaps(); + cc.ReturnToAlice(); + } + + Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout)); + var held = ghf.Result; + + var foof = held.Foo(123, true); + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void UninterceptInCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestMoreStuffImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var counters2 = new Counters(); + var cap = policy.Attach(new TestInterfaceImpl(counters2)); + + var foof = main.CallFoo(cap); + + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.UninterceptInCaps(); + cc.ForwardToBob(); + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void MultiAttachAndDetach() + { + var a = new MyPolicy("a"); + var b = new MyPolicy("b"); + var c = new MyPolicy("c"); + + var counters = new Counters(); + var impl = new TestInterfaceImpl(counters); + + var implA = a.Attach(impl); + var implAbc = b.Attach(a.Attach(b.Attach(c.Attach(implA)))); + var implAc = b.Detach(implAbc); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + server.Main = implAc; + using (var main = client.GetMain()) + { + var foof = main.Foo(123, true); + + var ccf1 = c.Calls.ReceiveAsync(); + Assert.IsTrue(ccf1.Wait(MediumNonDbgTimeout)); + var cc1 = ccf1.Result; + cc1.ForwardToBob(); + + var ccf2 = a.Calls.ReceiveAsync(); + Assert.IsTrue(ccf2.Wait(MediumNonDbgTimeout)); + var cc2 = ccf2.Result; + cc2.ForwardToBob(); + + Assert.IsTrue(a.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc2.ReturnToAlice(); + + Assert.IsTrue(c.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc1.ReturnToAlice(); + + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + } +} diff --git a/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs b/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs index 292c678..58bb57a 100644 --- a/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs +++ b/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs @@ -8,16 +8,16 @@ namespace Capnp.Net.Runtime.Tests { class ProvidedCapabilityMultiCallMock : Skeleton { - readonly BufferBlock _ccs = new BufferBlock(); + readonly BufferBlock _ccs = new BufferBlock(); public override Task Invoke(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken cancellationToken = default(CancellationToken)) { - var cc = new CallContext(interfaceId, methodId, args, cancellationToken); + var cc = new TestCallContext(interfaceId, methodId, args, cancellationToken); Assert.IsTrue(_ccs.Post(cc)); return cc.Result.Task; } - public Task WhenCalled => _ccs.ReceiveAsync(); + public Task WhenCalled => _ccs.ReceiveAsync(); } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpc.cs b/Capnp.Net.Runtime.Tests/TcpRpc.cs index fef274c..b5482e4 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpc.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpc.cs @@ -141,7 +141,7 @@ namespace Capnp.Net.Runtime.Tests var args = DynamicSerializerState.CreateForRpc(); args.SetStruct(1, 0); args.WriteData(0, 123456); - using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; @@ -182,7 +182,7 @@ namespace Capnp.Net.Runtime.Tests var args = DynamicSerializerState.CreateForRpc(); args.SetStruct(1, 0); args.WriteData(0, 123456); - using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; @@ -223,7 +223,7 @@ namespace Capnp.Net.Runtime.Tests args.SetStruct(1, 0); args.WriteData(0, 123456); CancellationToken ctx; - using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; @@ -267,7 +267,7 @@ namespace Capnp.Net.Runtime.Tests args.WriteData(0, 123456); CancellationToken ctx; IPromisedAnswer answer; - using (answer = main.Call(0x1234567812345678, 0x3333, args, false)) + using (answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; @@ -324,7 +324,7 @@ namespace Capnp.Net.Runtime.Tests var args = DynamicSerializerState.CreateForRpc(); args.SetStruct(1, 0); args.WriteData(0, 123456); - using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; @@ -361,7 +361,7 @@ namespace Capnp.Net.Runtime.Tests var args = DynamicSerializerState.CreateForRpc(); args.SetStruct(1, 0); args.WriteData(0, 123456); - using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); @@ -372,7 +372,7 @@ namespace Capnp.Net.Runtime.Tests args2.SetStruct(1, 0); args2.WriteData(0, 654321); - using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2, false)) + using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2)) { (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; Assert.AreEqual(0x1234567812345678, interfaceId); @@ -434,7 +434,7 @@ namespace Capnp.Net.Runtime.Tests var args = DynamicSerializerState.CreateForRpc(); args.SetStruct(1, 0); args.WriteData(0, 123456); - using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); @@ -464,7 +464,7 @@ namespace Capnp.Net.Runtime.Tests args2.SetStruct(1, 0); args2.WriteData(0, 654321); - using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2, false)) + using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2)) { Assert.IsTrue(answer.WhenReturned.Wait(MediumTimeout)); Assert.IsTrue(mock2.WhenCalled.Wait(MediumTimeout)); @@ -510,7 +510,7 @@ namespace Capnp.Net.Runtime.Tests var args = DynamicSerializerState.CreateForRpc(); args.SetStruct(1, 0); args.WriteData(0, 123456); - using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); @@ -524,8 +524,8 @@ namespace Capnp.Net.Runtime.Tests args3.SetStruct(1, 0); args3.WriteData(0, 222222); - using (var answer2 = pipelined.Call(0x1111111111111111, 0x1111, args2, false)) - using (var answer3 = pipelined.Call(0x2222222222222222, 0x2222, args3, false)) + using (var answer2 = pipelined.Call(0x1111111111111111, 0x1111, args2)) + using (var answer3 = pipelined.Call(0x2222222222222222, 0x2222, args3)) { (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; @@ -555,8 +555,8 @@ namespace Capnp.Net.Runtime.Tests args5.SetStruct(1, 0); args5.WriteData(0, 444444); - using (var answer4 = pipelined.Call(0x3333333333333333, 0x3333, args4, false)) - using (var answer5 = pipelined.Call(0x4444444444444444, 0x4444, args5, false)) + using (var answer4 = pipelined.Call(0x3333333333333333, 0x3333, args4)) + using (var answer5 = pipelined.Call(0x4444444444444444, 0x4444, args5)) { var call2 = mock2.WhenCalled; var call3 = mock2.WhenCalled; @@ -628,7 +628,7 @@ namespace Capnp.Net.Runtime.Tests args.SetStruct(1, 0); args.WriteData(0, 123456); BareProxy pipelined; - using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); @@ -643,7 +643,7 @@ namespace Capnp.Net.Runtime.Tests try { - pipelined.Call(0x8765432187654321, 0x4444, args2, false); + pipelined.Call(0x8765432187654321, 0x4444, args2); Assert.Fail("Expected an exception here"); } catch (ObjectDisposedException) @@ -675,7 +675,7 @@ namespace Capnp.Net.Runtime.Tests args.SetStruct(1, 0); args.WriteData(0, 123456); IPromisedAnswer answer2; - using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) + using (var answer = main.Call(0x1234567812345678, 0x3333, args)) { Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); @@ -685,7 +685,7 @@ namespace Capnp.Net.Runtime.Tests args2.SetStruct(1, 0); args2.WriteData(0, 654321); - answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2, false); + answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2); } using (answer2) diff --git a/Capnp.Net.Runtime.Tests/CallContext.cs b/Capnp.Net.Runtime.Tests/TestCallContext.cs similarity index 80% rename from Capnp.Net.Runtime.Tests/CallContext.cs rename to Capnp.Net.Runtime.Tests/TestCallContext.cs index 07d5329..21dc3d2 100644 --- a/Capnp.Net.Runtime.Tests/CallContext.cs +++ b/Capnp.Net.Runtime.Tests/TestCallContext.cs @@ -4,9 +4,9 @@ using Capnp.Rpc; namespace Capnp.Net.Runtime.Tests { - class CallContext + class TestCallContext { - public CallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct) + public TestCallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct) { InterfaceId = interfaceId; MethodId = methodId; diff --git a/Capnp.Net.Runtime/DeserializerState.cs b/Capnp.Net.Runtime/DeserializerState.cs index 73f44df..830ff83 100644 --- a/Capnp.Net.Runtime/DeserializerState.cs +++ b/Capnp.Net.Runtime/DeserializerState.cs @@ -47,7 +47,7 @@ namespace Capnp /// /// The capabilities imported from the capability table. Only valid in RPC context. /// - public IReadOnlyList Caps { get; set; } + public IList Caps { get; set; } /// /// Current segment (essentially Segments[CurrentSegmentIndex] /// diff --git a/Capnp.Net.Runtime/Reserializing.cs b/Capnp.Net.Runtime/Reserializing.cs index 8f0c964..a2bbf6c 100644 --- a/Capnp.Net.Runtime/Reserializing.cs +++ b/Capnp.Net.Runtime/Reserializing.cs @@ -25,6 +25,12 @@ namespace Capnp if (to == null) throw new ArgumentNullException(nameof(to)); + if (from.Caps != null && to.Caps != null) + { + to.Caps.Clear(); + to.Caps.AddRange(from.Caps); + } + var ds = to.Rewrap(); IReadOnlyList items; diff --git a/Capnp.Net.Runtime/Rpc/BareProxy.cs b/Capnp.Net.Runtime/Rpc/BareProxy.cs index 8f133bc..db68fc4 100644 --- a/Capnp.Net.Runtime/Rpc/BareProxy.cs +++ b/Capnp.Net.Runtime/Rpc/BareProxy.cs @@ -45,9 +45,9 @@ /// Method arguments /// Whether it is a tail call /// Answer promise - public IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall) + public IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args) { - return base.Call(interfaceId, methodId, args, tailCall); + return base.Call(interfaceId, methodId, args, default); } } } diff --git a/Capnp.Net.Runtime/Rpc/ConsumedCapability.cs b/Capnp.Net.Runtime/Rpc/ConsumedCapability.cs index 91863f0..8152d44 100644 --- a/Capnp.Net.Runtime/Rpc/ConsumedCapability.cs +++ b/Capnp.Net.Runtime/Rpc/ConsumedCapability.cs @@ -10,7 +10,7 @@ namespace Capnp.Rpc /// public abstract class ConsumedCapability { - internal abstract IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall); + internal abstract IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args); /// /// Request the RPC engine to release this capability from its import table, diff --git a/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs b/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs new file mode 100644 index 0000000..67bb1e6 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs @@ -0,0 +1,339 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Capnp.Rpc.Interception +{ + + /// + /// Context of an intercepted call. Provides access to parameters and results, + /// and the possibility to redirect the call to some other capability. + /// + public class CallContext + { + class PromisedAnswer : IPromisedAnswer + { + readonly CallContext _callContext; + readonly TaskCompletionSource _futureResult = new TaskCompletionSource(); + readonly CancellationTokenSource _cancelFromAlice = new CancellationTokenSource(); + + public PromisedAnswer(CallContext callContext) + { + _callContext = callContext; + } + + public Task WhenReturned => _futureResult.Task; + public CancellationToken CancelFromAlice => _cancelFromAlice.Token; + + async Task AccessWhenReturned(MemberAccessPath access) + { + await WhenReturned; + return new Proxy(Access(access)); + } + + public ConsumedCapability Access(MemberAccessPath access) + { + if (_futureResult.Task.IsCompleted) + { + try + { + return access.Eval(WhenReturned.Result); + } + catch (AggregateException exception) + { + throw exception.InnerException; + } + } + else + { + return new LazyCapability(AccessWhenReturned(access)); + } + } + + public void Dispose() + { + try + { + _cancelFromAlice.Cancel(); + } + catch (ObjectDisposedException) + { + // May happen when cancellation request from Alice arrives after return. + } + } + + public void Return() + { + try + { + if (_callContext.ReturnCanceled) + { + _futureResult.SetCanceled(); + } + else if (_callContext.Exception != null) + { + _futureResult.SetException(new RpcException(_callContext.Exception)); + } + else + { + _futureResult.SetResult(_callContext.OutArgs); + } + } + finally + { + _cancelFromAlice.Dispose(); + } + } + } + + /// + /// Target interface ID of this call + /// + public ulong InterfaceId { get; } + + /// + /// Target method ID of this call + /// + public ushort MethodId { get; } + + /// + /// Lifecycle state of this call + /// + public InterceptionState State { get; private set; } + + /// + /// Input arguments + /// + public SerializerState InArgs { get; set; } + + /// + /// Output arguments ("return value") + /// + public DeserializerState OutArgs { get; set; } + + /// + /// Exception text, or null if there is no exception + /// + public string Exception { get; set; } + + /// + /// Whether the call should return in canceled state to Alice (the original caller). + /// In case of forwarding () the property is automatically set according + /// to the cancellation state of Bob's answer. However, you may override it: + /// + /// Setting it from 'false' to 'true' means that we pretend Alice a canceled call. + /// If Alice never requested cancellation this will surprise her pretty much. + /// Setting it from 'true' to 'false' overrides an existing cancellation. Since + /// we did not receive any output arguments from Bob (due to the cancellation), you *must* provide + /// either or . + /// + /// + public bool ReturnCanceled { get; set; } + + /// + /// The cancellation token *from Alice* tells us when the original caller resigns from the call. + /// + public CancellationToken CancelFromAlice { get; private set; } + + /// + /// The cancellation token *to Bob* tells the target capability when we resign from the forwarded call. + /// It is initialized with . Override it to achieve different behaviors: + /// E.g. set it to CancellationToken.None for "hiding" any cancellation request from Alice. + /// Set it to new CancellationToken(true) to pretend Bob a cancellation request. + /// + public CancellationToken CancelToBob { get; set; } + + /// + /// Target capability. May be one of the following: + /// + /// Capability interface implementation + /// A -derived object + /// A -derived object + /// A -derived object (low level capability) + /// null + /// + /// + public object Bob + { + get => _bob; + set + { + if (value != _bob) + { + BobProxy?.Dispose(); + BobProxy = null; + + _bob = value; + + switch (value) + { + case Proxy proxy: + BobProxy = proxy; + break; + + case Skeleton skeleton: + BobProxy = CapabilityReflection.CreateProxy( + LocalCapability.Create(skeleton)); + break; + + case ConsumedCapability cap: + BobProxy = CapabilityReflection.CreateProxy(cap); + break; + + case null: + break; + + default: + BobProxy = CapabilityReflection.CreateProxy( + LocalCapability.Create( + Skeleton.GetOrCreateSkeleton(value, false))); + break; + } + } + } + } + + internal Proxy BobProxy { get; private set; } + + readonly CensorCapability _censorCapability; + PromisedAnswer _promisedAnswer; + object _bob; + + internal IPromisedAnswer Answer => _promisedAnswer; + + internal CallContext(CensorCapability censorCapability, ulong interfaceId, ushort methodId, SerializerState inArgs) + { + _censorCapability = censorCapability; + _promisedAnswer = new PromisedAnswer(this); + + CancelFromAlice = _promisedAnswer.CancelFromAlice; + CancelToBob = CancelFromAlice; + Bob = censorCapability.InterceptedCapability; + InterfaceId = interfaceId; + MethodId = methodId; + InArgs = inArgs; + State = InterceptionState.RequestedFromAlice; + } + + static void InterceptCaps(DeserializerState state, IInterceptionPolicy policy) + { + if (state.Caps != null) + { + for (int i = 0; i < state.Caps.Count; i++) + { + state.Caps[i] = policy.Attach(state.Caps[i]); + state.Caps[i].AddRef(); + } + } + } + + static void UninterceptCaps(DeserializerState state, IInterceptionPolicy policy) + { + if (state.Caps != null) + { + for (int i = 0; i < state.Caps.Count; i++) + { + state.Caps[i] = policy.Detach(state.Caps[i]); + state.Caps[i].AddRef(); + } + } + } + + /// + /// Intercepts all capabilies inside the input arguments + /// + /// Policy to use, or null to further use present policy + public void InterceptInCaps(IInterceptionPolicy policyOverride = null) + { + InterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy); + } + + /// + /// Intercepts all capabilies inside the output arguments + /// + /// Policy to use, or null to further use present policy + public void InterceptOutCaps(IInterceptionPolicy policyOverride = null) + { + InterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy); + } + + /// + /// Unintercepts all capabilies inside the input arguments + /// + /// Policy to remove, or null to remove present policy + public void UninterceptInCaps(IInterceptionPolicy policyOverride = null) + { + UninterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy); + } + + /// + /// Unintercepts all capabilies inside the output arguments + /// + /// Policy to remove, or null to remove present policy + public void UninterceptOutCaps(IInterceptionPolicy policyOverride = null) + { + UninterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy); + } + + /// + /// Forwards this intercepted call to the target capability ("Bob"). + /// + /// Optional cancellation token, requesting Bob to cancel the call + public void ForwardToBob() + { + if (Bob == null) + { + throw new InvalidOperationException("Bob is null"); + } + + var answer = BobProxy.Call(InterfaceId, MethodId, InArgs.Rewrap(), default, CancelToBob); + + State = InterceptionState.ForwardedToBob; + + async void ChangeStateWhenReturned() + { + using (answer) + { + try + { + OutArgs = await answer.WhenReturned; + } + catch (TaskCanceledException) + { + ReturnCanceled = true; + } + catch (System.Exception exception) + { + Exception = exception.Message; + } + } + + State = InterceptionState.ReturnedFromBob; + + _censorCapability.Policy.OnReturnFromBob(this); + } + + ChangeStateWhenReturned(); + } + + /// + /// Returns this intercepted call to the caller ("Alice"). + /// + public void ReturnToAlice() + { + try + { + _promisedAnswer.Return(); + } + catch (InvalidOperationException) + { + throw new InvalidOperationException("The call was already returned"); + } + + State = InterceptionState.ReturnedToAlice; + } + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/CensorCapability.cs b/Capnp.Net.Runtime/Rpc/Interception/CensorCapability.cs new file mode 100644 index 0000000..abaa526 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/CensorCapability.cs @@ -0,0 +1,44 @@ +namespace Capnp.Rpc.Interception +{ + class CensorCapability : RefCountingCapability + { + public CensorCapability(ConsumedCapability interceptedCapability, IInterceptionPolicy policy) + { + InterceptedCapability = interceptedCapability; + interceptedCapability.AddRef(); + Policy = policy; + MyVine = Vine.Create(this); + } + + public ConsumedCapability InterceptedCapability { get; } + public IInterceptionPolicy Policy { get; } + internal Skeleton MyVine { get; } + + protected override void ReleaseRemotely() + { + InterceptedCapability.Release(); + } + + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) + { + var cc = new CallContext(this, interfaceId, methodId, args); + Policy.OnCallFromAlice(cc); + return cc.Answer; + } + + internal override void Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer) + { + writer.which = CapDescriptor.WHICH.SenderHosted; + writer.SenderHosted = endpoint.AllocateExport(MyVine, out bool _); + } + + internal override void Freeze(out IRpcEndpoint boundEndpoint) + { + boundEndpoint = null; + } + + internal override void Unfreeze() + { + } + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/IInterceptionPolicy.cs b/Capnp.Net.Runtime/Rpc/Interception/IInterceptionPolicy.cs new file mode 100644 index 0000000..a4d5706 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/IInterceptionPolicy.cs @@ -0,0 +1,23 @@ +using System; + +namespace Capnp.Rpc.Interception +{ + /// + /// An interception policy implements callbacks for outgoing calls and returning forwarded calls. + /// + public interface IInterceptionPolicy: IEquatable + { + /// + /// A caller ("Alice") initiated a new call, which is now intercepted. + /// + /// Context object + void OnCallFromAlice(CallContext callContext); + + /// + /// Given that the intercepted call was forwarded, it returned now from the target ("Bob") + /// and may (or may not) be returned to the original caller ("Alice"). + /// + /// + void OnReturnFromBob(CallContext callContext); + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/InterceptionState.cs b/Capnp.Net.Runtime/Rpc/Interception/InterceptionState.cs new file mode 100644 index 0000000..67eea0e --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/InterceptionState.cs @@ -0,0 +1,28 @@ +namespace Capnp.Rpc.Interception +{ + /// + /// The state of an intercepted call from Alice to Bob. + /// + public enum InterceptionState + { + /// + /// Alice initiated the call, but it was neither forwarded to Bob nor finished. + /// + RequestedFromAlice, + + /// + /// The call was forwarded to Bob. + /// + ForwardedToBob, + + /// + /// The call returned from Bob (to whom it was forwarded), but no result was yet forwarded to Alice. + /// + ReturnedFromBob, + + /// + /// The call was returned to Alice (either with results, exception, or cancelled) + /// + ReturnedToAlice + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/Interceptor.cs b/Capnp.Net.Runtime/Rpc/Interception/Interceptor.cs new file mode 100644 index 0000000..91c3ee3 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/Interceptor.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +namespace Capnp.Rpc.Interception +{ + /// + /// This static class provides extension methods for intercepting and unintercepting capabilities. + /// + public static class Interceptor + { + static readonly ConditionalWeakTable _interceptMap = + new ConditionalWeakTable(); + + /// + /// Attach this policy to given capability. + /// + /// Capability interface type + /// Policy to attach + /// Capability to censor + /// Censored capability instance + /// is null or + /// is null + public static TCap Attach(this IInterceptionPolicy policy, TCap cap) + where TCap: class + { + if (policy == null) + throw new ArgumentNullException(nameof(policy)); + + if (cap == null) + throw new ArgumentNullException(nameof(cap)); + + var cur = cap as CensorCapability; + + while (cur != null) + { + if (policy.Equals(cur.Policy)) + { + return cap; + } + + cur = cur.InterceptedCapability as CensorCapability; + } + + switch (cap) + { + case Proxy proxy: + return CapabilityReflection.CreateProxy(Attach(policy, proxy.ConsumedCap)) as TCap; + + case ConsumedCapability ccap: + return new CensorCapability(ccap, policy) as TCap; + + default: + return Attach(policy, + CapabilityReflection.CreateProxy( + LocalCapability.Create( + Skeleton.GetOrCreateSkeleton(cap, false))) as TCap); + } + } + + /// + /// Detach this policy from given (censored) capability. + /// + /// Capability interface type + /// Policy to detach + /// Capability to clean + /// Clean capability instance (at least, without this interception policy) + /// is null or + /// is null + public static TCap Detach(this IInterceptionPolicy policy, TCap cap) + where TCap: class + { + if (policy == null) + throw new ArgumentNullException(nameof(policy)); + + if (cap == null) + throw new ArgumentNullException(nameof(cap)); + + switch (cap) + { + case Proxy proxy: + return CapabilityReflection.CreateProxy(Detach(policy, proxy.ConsumedCap)) as TCap; + + case CensorCapability ccap: + { + var cur = ccap; + var stk = new Stack(); + + do + { + if (policy.Equals(cur.Policy)) + { + var cur2 = cur.InterceptedCapability; + + foreach (var p in stk) + { + cur2 = p.Attach(cur2); + } + return cur2 as TCap; + } + + stk.Push(cur.Policy); + cur = cur.InterceptedCapability as CensorCapability; + } + while (cur != null); + + return ccap as TCap; + } + + default: + return cap; + } + } + } +} diff --git a/Capnp.Net.Runtime/Rpc/LazyCapability.cs b/Capnp.Net.Runtime/Rpc/LazyCapability.cs index 607b4a6..ed2ac40 100644 --- a/Capnp.Net.Runtime/Rpc/LazyCapability.cs +++ b/Capnp.Net.Runtime/Rpc/LazyCapability.cs @@ -81,9 +81,7 @@ namespace Capnp.Rpc public Task WhenResolved { get; } - async Task CallImpl(ulong interfaceId, ushort methodId, - DynamicSerializerState args, bool pipeline, - CancellationToken cancellationToken) + async Task CallImpl(ulong interfaceId, ushort methodId, DynamicSerializerState args, CancellationToken cancellationToken) { var cap = await WhenResolved; @@ -92,7 +90,7 @@ namespace Capnp.Rpc if (cap == null) throw new RpcException("Broken capability"); - var call = cap.Call(interfaceId, methodId, args, pipeline); + var call = cap.Call(interfaceId, methodId, args, default); var whenReturned = call.WhenReturned; using (var registration = cancellationToken.Register(call.Dispose)) @@ -101,10 +99,10 @@ namespace Capnp.Rpc } } - internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) { var cts = new CancellationTokenSource(); - return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, pipeline, cts.Token)); + return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, cts.Token)); } } } diff --git a/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs b/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs index b4b5e2b..3dbf83b 100644 --- a/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs +++ b/Capnp.Net.Runtime/Rpc/LocalAnswerCapability.cs @@ -57,9 +57,7 @@ namespace Capnp.Rpc } } - async Task CallImpl(ulong interfaceId, ushort methodId, - DynamicSerializerState args, bool pipeline, - CancellationToken cancellationToken) + async Task CallImpl(ulong interfaceId, ushort methodId, DynamicSerializerState args, CancellationToken cancellationToken) { var cap = await AwaitResolved(); @@ -68,7 +66,7 @@ namespace Capnp.Rpc if (cap == null) throw new RpcException("Broken capability"); - var call = cap.Call(interfaceId, methodId, args, pipeline); + var call = cap.Call(interfaceId, methodId, args, default); var whenReturned = call.WhenReturned; using (var registration = cancellationToken.Register(() => call.Dispose())) @@ -77,10 +75,10 @@ namespace Capnp.Rpc } } - internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) { var cts = new CancellationTokenSource(); - return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, pipeline, cts.Token)); + return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, cts.Token)); } protected override void ReleaseRemotely() diff --git a/Capnp.Net.Runtime/Rpc/LocalCapability.cs b/Capnp.Net.Runtime/Rpc/LocalCapability.cs index 60782fc..3328a5f 100644 --- a/Capnp.Net.Runtime/Rpc/LocalCapability.cs +++ b/Capnp.Net.Runtime/Rpc/LocalCapability.cs @@ -42,7 +42,7 @@ namespace Capnp.Rpc ProvidedCap.Relinquish(); } - internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) { var cts = new CancellationTokenSource(); var call = ProvidedCap.Invoke(interfaceId, methodId, args, cts.Token); diff --git a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs index 2ef55af..4031501 100644 --- a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs +++ b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs @@ -38,40 +38,6 @@ namespace Capnp.Rpc } } - //public Task WhenReady => ChainedAwaitWhenReady(); - - //public void Pipeline(PromisedAnswer.READER rd, Action action, Action error) - //{ - // lock (_reentrancyBlocker) - // { - // if (_chainedTask == null) - // { - // _chainedTask = InitialAwaitWhenReady(); - // } - - // _chainedTask = _chainedTask.ContinueWith(t => - // { - // bool rethrow = true; - - // try - // { - // t.Wait(); - // rethrow = false; - // EvaluateProxyAndCallContinuation(rd, action); - // } - // catch (AggregateException aggregateException) - // { - // var innerException = aggregateException.InnerException; - - // error(innerException); - - // if (rethrow) throw innerException; - // } - // }, - // TaskContinuationOptions.ExecuteSynchronously); - // } - //} - async Task AwaitChainedTask(Task chainedTask, Func, Task> func) { try @@ -197,83 +163,8 @@ namespace Capnp.Rpc }); } - //Task ChainedAwaitWhenReady() - //{ - // async Task AwaitChainedTask(Task chainedTask) - // { - // await chainedTask; - // return _callTask.Result; - // } - - // Task resultTask; - - // lock (_reentrancyBlocker) - // { - // if (_chainedTask == null) - // { - // _chainedTask = InitialAwaitWhenReady(); - // } - - // resultTask = AwaitChainedTask(_chainedTask); - // _chainedTask = resultTask; - // } - - // return resultTask; - //} - public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None; - //void EvaluateProxyAndCallContinuation(PromisedAnswer.READER rd, Action action) - //{ - // var result = _callTask.Result; - - // DeserializerState cur = result; - - // foreach (var op in rd.Transform) - // { - // switch (op.which) - // { - // case PromisedAnswer.Op.WHICH.GetPointerField: - // try - // { - // cur = cur.StructReadPointer(op.GetPointerField); - // } - // catch (System.Exception) - // { - // throw new ArgumentOutOfRangeException("Illegal pointer field in transformation operation"); - // } - // break; - - // case PromisedAnswer.Op.WHICH.Noop: - // break; - - // default: - // throw new ArgumentOutOfRangeException("Unknown transformation operation"); - // } - // } - - // Proxy proxy; - - // switch (cur.Kind) - // { - // case ObjectKind.Capability: - // try - // { - // var cap = result.MsgBuilder.Caps[(int)cur.CapabilityIndex]; - // proxy = new Proxy(cap ?? LazyCapability.Null); - // } - // catch (ArgumentOutOfRangeException) - // { - // throw new ArgumentOutOfRangeException("Bad capability table in internal answer - internal error?"); - // } - // action(proxy); - // break; - - // default: - // throw new ArgumentOutOfRangeException("Transformation did not result in a capability"); - // } - //} - public async void Dispose() { if (_cts != null) diff --git a/Capnp.Net.Runtime/Rpc/PromisedCapability.cs b/Capnp.Net.Runtime/Rpc/PromisedCapability.cs index 87718a7..b704c7e 100644 --- a/Capnp.Net.Runtime/Rpc/PromisedCapability.cs +++ b/Capnp.Net.Runtime/Rpc/PromisedCapability.cs @@ -169,13 +169,13 @@ namespace Capnp.Rpc wr.ImportedCap = _remoteId; } - internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) { lock (_reentrancyBlocker) { if (_resolvedCap.Task.IsCompleted) { - return CallOnResolution(interfaceId, methodId, args, pipeline); + return CallOnResolution(interfaceId, methodId, args); } else { @@ -184,7 +184,7 @@ namespace Capnp.Rpc } } - var promisedAnswer = base.DoCall(interfaceId, methodId, args, pipeline); + var promisedAnswer = base.DoCall(interfaceId, methodId, args); TrackCall(promisedAnswer.WhenReturned); return promisedAnswer; } diff --git a/Capnp.Net.Runtime/Rpc/Proxy.cs b/Capnp.Net.Runtime/Rpc/Proxy.cs index c0516c6..1574c5d 100644 --- a/Capnp.Net.Runtime/Rpc/Proxy.cs +++ b/Capnp.Net.Runtime/Rpc/Proxy.cs @@ -66,13 +66,15 @@ namespace Capnp.Rpc /// Interface ID to call /// Method ID to call /// Method arguments ("param struct") - /// Whether it is a tail call + /// This flag is ignored. It is there to preserve compatibility with the + /// code generator and will be removed in future versions. /// For cancelling an ongoing method call /// An answer promise /// This instance was disposed, or transport-layer stream was disposed. /// Capability is broken. /// An I/O error occurs. - protected internal IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall, CancellationToken cancellationToken = default) + protected internal IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args, + bool obsoleteAndIgnored, CancellationToken cancellationToken = default) { if (_disposedValue) throw new ObjectDisposedException(nameof(Proxy)); @@ -80,7 +82,7 @@ namespace Capnp.Rpc if (ConsumedCap == null) throw new InvalidOperationException("Cannot call null capability"); - var answer = ConsumedCap.DoCall(interfaceId, methodId, args, tailCall); + var answer = ConsumedCap.DoCall(interfaceId, methodId, args); if (cancellationToken.CanBeCanceled) { diff --git a/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs b/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs index e192a15..27f8961 100644 --- a/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs +++ b/Capnp.Net.Runtime/Rpc/RemoteAnswerCapability.cs @@ -99,7 +99,7 @@ namespace Capnp.Rpc _access.Serialize(wr.PromisedAnswer); } - internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) { lock (_question.ReentrancyBlocker) { @@ -111,7 +111,7 @@ namespace Capnp.Rpc throw new RpcException("Answer did not resolve to expected capability"); } - return CallOnResolution(interfaceId, methodId, args, pipeline); + return CallOnResolution(interfaceId, methodId, args); } else { @@ -130,7 +130,7 @@ namespace Capnp.Rpc _question.DisallowFinish(); ++_pendingCallsOnPromise; - var promisedAnswer = base.DoCall(interfaceId, methodId, args, pipeline); + var promisedAnswer = base.DoCall(interfaceId, methodId, args); ReAllowFinishWhenDone(promisedAnswer.WhenReturned); async void DecrementPendingCallsOnPromiseWhenReturned() diff --git a/Capnp.Net.Runtime/Rpc/RemoteCapability.cs b/Capnp.Net.Runtime/Rpc/RemoteCapability.cs index e78d7c5..2dc470e 100644 --- a/Capnp.Net.Runtime/Rpc/RemoteCapability.cs +++ b/Capnp.Net.Runtime/Rpc/RemoteCapability.cs @@ -15,7 +15,7 @@ namespace Capnp.Rpc _ep = ep; } - internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall) + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args) { var call = SetupMessage(args, interfaceId, methodId); Debug.Assert(call.Target.which != MessageTarget.WHICH.undefined); diff --git a/Capnp.Net.Runtime/Rpc/RemoteResolvingCapability.cs b/Capnp.Net.Runtime/Rpc/RemoteResolvingCapability.cs index 65550f0..4c8c9fb 100644 --- a/Capnp.Net.Runtime/Rpc/RemoteResolvingCapability.cs +++ b/Capnp.Net.Runtime/Rpc/RemoteResolvingCapability.cs @@ -29,7 +29,7 @@ namespace Capnp.Rpc protected abstract void GetMessageTarget(MessageTarget.WRITER wr); - protected IPromisedAnswer CallOnResolution(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) + protected IPromisedAnswer CallOnResolution(ulong interfaceId, ushort methodId, DynamicSerializerState args) { try { @@ -62,7 +62,7 @@ namespace Capnp.Rpc #if DebugEmbargos Logger.LogDebug("Direct call"); #endif - return ResolvedCap.Call(interfaceId, methodId, args, pipeline); + return ResolvedCap.Call(interfaceId, methodId, args, default); } else { @@ -90,7 +90,7 @@ namespace Capnp.Rpc cancellationTokenSource.Token.ThrowIfCancellationRequested(); - return ResolvedCap.Call(interfaceId, methodId, args, pipeline); + return ResolvedCap.Call(interfaceId, methodId, args, default); }, TaskContinuationOptions.ExecuteSynchronously); diff --git a/Capnp.Net.Runtime/Rpc/RpcEngine.cs b/Capnp.Net.Runtime/Rpc/RpcEngine.cs index d9e2f0a..316e048 100644 --- a/Capnp.Net.Runtime/Rpc/RpcEngine.cs +++ b/Capnp.Net.Runtime/Rpc/RpcEngine.cs @@ -1311,7 +1311,7 @@ namespace Capnp.Rpc } } - public IReadOnlyList ImportCapTable(Payload.READER payload) + public IList ImportCapTable(Payload.READER payload) { var list = new List(); diff --git a/Capnp.Net.Runtime/Rpc/Skeleton.cs b/Capnp.Net.Runtime/Rpc/Skeleton.cs index 2be39c9..856f2a1 100644 --- a/Capnp.Net.Runtime/Rpc/Skeleton.cs +++ b/Capnp.Net.Runtime/Rpc/Skeleton.cs @@ -39,6 +39,9 @@ namespace Capnp.Rpc internal static Skeleton GetOrCreateSkeleton(T impl, bool addRef) where T: class { + if (impl == null) + throw new ArgumentNullException(nameof(impl)); + if (impl is Skeleton skel) return skel; diff --git a/Capnp.Net.Runtime/Rpc/Vine.cs b/Capnp.Net.Runtime/Rpc/Vine.cs index feb3bde..47bc10e 100644 --- a/Capnp.Net.Runtime/Rpc/Vine.cs +++ b/Capnp.Net.Runtime/Rpc/Vine.cs @@ -32,7 +32,7 @@ namespace Capnp.Rpc ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken cancellationToken = default) { - var promisedAnswer = Proxy.Call(interfaceId, methodId, (DynamicSerializerState)args, false); + var promisedAnswer = Proxy.Call(interfaceId, methodId, (DynamicSerializerState)args, default); if (promisedAnswer is PendingQuestion pendingQuestion && pendingQuestion.RpcEndpoint == Impatient.AskingEndpoint) {