diff --git a/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj b/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj index eb11e5e..bc841b7 100644 --- a/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj +++ b/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj @@ -9,11 +9,11 @@ - + @@ -30,6 +30,7 @@ + diff --git a/Capnp.Net.Runtime.Tests/Interception.cs b/Capnp.Net.Runtime.Tests/Interception.cs new file mode 100644 index 0000000..cc95e9b --- /dev/null +++ b/Capnp.Net.Runtime.Tests/Interception.cs @@ -0,0 +1,587 @@ +using Capnp.Net.Runtime.Tests.GenImpls; +using Capnp.Rpc; +using Capnp.Rpc.Interception; +using Capnproto_test.Capnp.Test; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Threading.Tasks.Dataflow; + +namespace Capnp.Net.Runtime.Tests +{ + [TestClass] + public class Interception: TestBase + { + class MyPolicy : IInterceptionPolicy + { + readonly string _id; + readonly BufferBlock _callSubject = new BufferBlock(); + readonly BufferBlock _returnSubject = new BufferBlock(); + + public MyPolicy(string id) + { + _id = id; + } + + public bool Equals(IInterceptionPolicy other) + { + return other is MyPolicy myPolicy && _id.Equals(myPolicy._id); + } + + public override bool Equals(object obj) + { + return obj is IInterceptionPolicy other && Equals(other); + } + + public override int GetHashCode() + { + return _id.GetHashCode(); + } + + public void OnCallFromAlice(CallContext callContext) + { + Assert.IsTrue(_callSubject.Post(callContext)); + } + + public void OnReturnFromBob(CallContext callContext) + { + Assert.IsTrue(_returnSubject.Post(callContext)); + } + + public IReceivableSourceBlock Calls => _callSubject; + public IReceivableSourceBlock Returns => _returnSubject; + } + + [TestMethod] + public void InterceptServerSideObserveCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = policy.Attach(new TestInterfaceImpl(counters)); + using (var main = client.GetMain()) + { + var request1 = main.Foo(123, true, default); + var fcc = policy.Calls.ReceiveAsync(); + Assert.IsTrue(fcc.Wait(MediumNonDbgTimeout)); + var cc = fcc.Result; + + var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs); + Assert.AreEqual(123u, pr.I); + + cc.ForwardToBob(); + + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs); + Assert.AreEqual("foo", rr.X); + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.Wait(MediumNonDbgTimeout)); + + Assert.AreEqual("foo", request1.Result); + } + } + } + + [TestMethod] + public void InterceptClientSideModifyCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + + Assert.AreEqual(InterceptionState.RequestedFromAlice, cc.State); + + var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs); + Assert.AreEqual(321u, pr.I); + Assert.AreEqual(false, pr.J); + + var pw = cc.InArgs.Rewrap(); + pw.I = 123u; + pw.J = true; + + cc.ForwardToBob(); + + var rx = policy.Returns.ReceiveAsync(); + + // Racing against Bob's answer + Assert.IsTrue(cc.State == InterceptionState.ForwardedToBob || rx.IsCompleted); + + Assert.IsTrue(rx.Wait(MediumNonDbgTimeout)); + var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs); + Assert.AreEqual("foo", rr.X); + + Assert.IsFalse(request1.IsCompleted); + + var rw = ((DynamicSerializerState)cc.OutArgs).Rewrap(); + rw.X = "bar"; + cc.OutArgs = rw; + + Assert.AreEqual(InterceptionState.ReturnedFromBob, cc.State); + cc.ReturnToAlice(); + Assert.AreEqual(InterceptionState.ReturnedToAlice, cc.State); + + Assert.IsTrue(request1.IsCompleted); + + Assert.AreEqual("bar", request1.Result); + } + } + + } + + [TestMethod] + public void InterceptClientSideShortCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + var rw = SerializerState.CreateForRpc(); + rw.X = "bar"; + cc.OutArgs = rw; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + + Assert.AreEqual("bar", request1.Result); + } + } + } + + [TestMethod] + public void InterceptClientSideRejectCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + cc.Exception = "rejected"; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.IsTrue(request1.IsFaulted); + Assert.AreEqual("rejected", request1.Exception.InnerException.Message); + } + } + } + + [TestMethod] + public void InterceptClientSideCancelCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(321, false, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + cc.IsCanceled = true; + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.IsTrue(request1.IsCanceled); + } + } + } + + [TestMethod] + public void InterceptClientSideRedirectCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var request1 = main.Foo(123, true, default); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + Assert.IsFalse(request1.IsCompleted); + + var counters2 = new Counters(); + var impl2 = new TestInterfaceImpl(counters2); + cc.Bob = impl2; + cc.ForwardToBob(); + + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + + cc.ReturnToAlice(); + + Assert.IsTrue(request1.IsCompleted); + Assert.AreEqual("foo", request1.Result); + Assert.AreEqual(0, counters.CallCount); + Assert.AreEqual(1, counters2.CallCount); + } + } + } + + [TestMethod] + public void InterfaceAndMethodId() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestInterfaceImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var baz = main.Baz(new TestAllTypes()); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + + Assert.IsTrue(cc.MethodId == 2); + Assert.AreEqual(new TestInterface_Skeleton().InterfaceId, cc.InterfaceId); + } + } + } + + [TestMethod] + public void TailCall() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestTailCallerImpl(counters); + using (var main = client.GetMain()) + { + var calleeCallCount = new Counters(); + var callee = policy.Attach(new TestTailCalleeImpl(calleeCallCount)); + + var promise = main.Foo(456, callee, default); + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + var cc = ccf.Result; + cc.ForwardToBob(); + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); + Assert.AreEqual("from TestTailCaller", promise.Result.T); + } + } + } + + [TestMethod] + public void InterceptInCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestMoreStuffImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var counters2 = new Counters(); + var cap = new TestInterfaceImpl(counters2); + var promise = main.CallFoo(cap); + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.InterceptInCaps(); + cc.ForwardToBob(); + var cc2f = policy.Calls.ReceiveAsync(); + Assert.IsTrue(cc2f.Wait(MediumNonDbgTimeout)); + var cc2 = cc2f.Result; + cc2.ForwardToBob(); + var cc2fr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(cc2fr.Wait(MediumNonDbgTimeout)); + Assert.AreSame(cc2, cc2fr.Result); + Assert.AreEqual(1, counters2.CallCount); + cc2.ReturnToAlice(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + Assert.AreSame(cc, ccfr.Result); + cc.ReturnToAlice(); + Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void InterceptOutCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = policy.Attach(new TestMoreStuffImpl(counters)); + using (var main = client.GetMain()) + { + var counters2 = new Counters(); + var cap = new TestInterfaceImpl(counters2); + main.Hold(cap); + { + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.ReturnToAlice(); + } + + var ghf = main.GetHeld(); + { + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.InterceptOutCaps(); + ccf.Result.ReturnToAlice(); + } + + Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout)); + var held = ghf.Result; + + var foof = held.Foo(123, true); + { + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.ReturnToAlice(); + } + + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void UninterceptOutCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestMoreStuffImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var counters2 = new Counters(); + var cap = new TestInterfaceImpl(counters2); + main.Hold(cap); + { + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.InterceptInCaps(); + cc.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + } + + main.CallHeld(); + { + // CallHeld + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.ForwardToBob(); + + // actual call on held cap. + var ccf = policy.Calls.ReceiveAsync(); + Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout)); + ccf.Result.ForwardToBob(); + + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + ccf.Result.ReturnToAlice(); + + ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + } + + var ghf = main.GetHeld(); + { + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.InterceptInCaps(); + cc.ForwardToBob(); + var ccfr = policy.Returns.ReceiveAsync(); + Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout)); + cc.UninterceptOutCaps(); + cc.ReturnToAlice(); + } + + Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout)); + var held = ghf.Result; + + var foof = held.Foo(123, true); + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void UninterceptInCaps() + { + var policy = new MyPolicy("a"); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + var counters = new Counters(); + server.Main = new TestMoreStuffImpl(counters); + using (var main = policy.Attach(client.GetMain())) + { + var counters2 = new Counters(); + var cap = policy.Attach(new TestInterfaceImpl(counters2)); + + var foof = main.CallFoo(cap); + + Assert.IsTrue(policy.Calls.TryReceive(out var cc)); + cc.UninterceptInCaps(); + cc.ForwardToBob(); + Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc.ReturnToAlice(); + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + [TestMethod] + public void MultiAttachAndDetach() + { + var a = new MyPolicy("a"); + var b = new MyPolicy("b"); + var c = new MyPolicy("c"); + + var counters = new Counters(); + var impl = new TestInterfaceImpl(counters); + + var implA = a.Attach(impl); + var implAbc = b.Attach(a.Attach(b.Attach(c.Attach(implA)))); + var implAc = b.Detach(implAbc); + + (var server, var client) = SetupClientServerPair(); + + using (server) + using (client) + { + client.WhenConnected.Wait(); + + server.Main = implAc; + using (var main = client.GetMain()) + { + var foof = main.Foo(123, true); + + var ccf1 = c.Calls.ReceiveAsync(); + Assert.IsTrue(ccf1.Wait(MediumNonDbgTimeout)); + var cc1 = ccf1.Result; + cc1.ForwardToBob(); + + var ccf2 = a.Calls.ReceiveAsync(); + Assert.IsTrue(ccf2.Wait(MediumNonDbgTimeout)); + var cc2 = ccf2.Result; + cc2.ForwardToBob(); + + Assert.IsTrue(a.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc2.ReturnToAlice(); + + Assert.IsTrue(c.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout)); + cc1.ReturnToAlice(); + + Assert.IsTrue(foof.Wait(MediumNonDbgTimeout)); + } + } + } + + } +} diff --git a/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs b/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs index 292c678..58bb57a 100644 --- a/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs +++ b/Capnp.Net.Runtime.Tests/ProvidedCapabilityMultiCallMock.cs @@ -8,16 +8,16 @@ namespace Capnp.Net.Runtime.Tests { class ProvidedCapabilityMultiCallMock : Skeleton { - readonly BufferBlock _ccs = new BufferBlock(); + readonly BufferBlock _ccs = new BufferBlock(); public override Task Invoke(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken cancellationToken = default(CancellationToken)) { - var cc = new CallContext(interfaceId, methodId, args, cancellationToken); + var cc = new TestCallContext(interfaceId, methodId, args, cancellationToken); Assert.IsTrue(_ccs.Post(cc)); return cc.Result.Task; } - public Task WhenCalled => _ccs.ReceiveAsync(); + public Task WhenCalled => _ccs.ReceiveAsync(); } } diff --git a/Capnp.Net.Runtime.Tests/CallContext.cs b/Capnp.Net.Runtime.Tests/TestCallContext.cs similarity index 80% rename from Capnp.Net.Runtime.Tests/CallContext.cs rename to Capnp.Net.Runtime.Tests/TestCallContext.cs index 07d5329..21dc3d2 100644 --- a/Capnp.Net.Runtime.Tests/CallContext.cs +++ b/Capnp.Net.Runtime.Tests/TestCallContext.cs @@ -4,9 +4,9 @@ using Capnp.Rpc; namespace Capnp.Net.Runtime.Tests { - class CallContext + class TestCallContext { - public CallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct) + public TestCallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct) { InterfaceId = interfaceId; MethodId = methodId; diff --git a/Capnp.Net.Runtime/DeserializerState.cs b/Capnp.Net.Runtime/DeserializerState.cs index 73f44df..830ff83 100644 --- a/Capnp.Net.Runtime/DeserializerState.cs +++ b/Capnp.Net.Runtime/DeserializerState.cs @@ -47,7 +47,7 @@ namespace Capnp /// /// The capabilities imported from the capability table. Only valid in RPC context. /// - public IReadOnlyList Caps { get; set; } + public IList Caps { get; set; } /// /// Current segment (essentially Segments[CurrentSegmentIndex] /// diff --git a/Capnp.Net.Runtime/Reserializing.cs b/Capnp.Net.Runtime/Reserializing.cs index 8f0c964..a2bbf6c 100644 --- a/Capnp.Net.Runtime/Reserializing.cs +++ b/Capnp.Net.Runtime/Reserializing.cs @@ -25,6 +25,12 @@ namespace Capnp if (to == null) throw new ArgumentNullException(nameof(to)); + if (from.Caps != null && to.Caps != null) + { + to.Caps.Clear(); + to.Caps.AddRange(from.Caps); + } + var ds = to.Rewrap(); IReadOnlyList items; diff --git a/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs b/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs new file mode 100644 index 0000000..2e4b704 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/CallContext.cs @@ -0,0 +1,238 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Capnp.Rpc.Interception +{ + + /// + /// Context of an intercepted call. Provides access to parameters and results, + /// and the possibility to redirect the call to some other capability. + /// + public class CallContext + { + class PromisedAnswer : IPromisedAnswer + { + CallContext _callContext; + TaskCompletionSource _futureResult = new TaskCompletionSource(); + + public PromisedAnswer(CallContext callContext) + { + _callContext = callContext; + } + + public Task WhenReturned => _futureResult.Task; + + async Task AccessWhenReturned(MemberAccessPath access) + { + await WhenReturned; + return new Proxy(Access(access)); + } + + public ConsumedCapability Access(MemberAccessPath access) + { + if (_futureResult.Task.IsCompleted) + { + try + { + return access.Eval(WhenReturned.Result); + } + catch (AggregateException exception) + { + throw exception.InnerException; + } + } + else + { + return new LazyCapability(AccessWhenReturned(access)); + } + } + + public void Dispose() + { + } + + public void Return() + { + if (_callContext.IsCanceled) + { + _futureResult.SetCanceled(); + } + else if (_callContext.Exception != null) + { + _futureResult.SetException(new RpcException(_callContext.Exception)); + } + else + { + _futureResult.SetResult(_callContext.OutArgs); + } + } + } + + public ulong InterfaceId { get; } + public ushort MethodId { get; } + public bool IsTailCall { get; } + public InterceptionState State { get; private set; } + public SerializerState InArgs { get; set; } + public DeserializerState OutArgs { get; set; } + public string Exception { get; set; } + public bool IsCanceled { get; set; } + public object Bob + { + get => _bob; + set + { + if (value != _bob) + { + BobProxy?.Dispose(); + BobProxy = null; + + _bob = value; + + switch (value) + { + case Proxy proxy: + BobProxy = proxy; + break; + + case Skeleton skeleton: + BobProxy = CapabilityReflection.CreateProxy( + LocalCapability.Create(skeleton)); + break; + + case ConsumedCapability cap: + BobProxy = CapabilityReflection.CreateProxy(cap); + break; + + case null: + break; + + default: + BobProxy = CapabilityReflection.CreateProxy( + LocalCapability.Create( + Skeleton.GetOrCreateSkeleton(value, false))); + break; + } + } + } + } + + internal Proxy BobProxy { get; private set; } + + readonly CensorCapability _censorCapability; + PromisedAnswer _promisedAnswer; + object _bob; + + internal IPromisedAnswer Answer => _promisedAnswer; + + internal CallContext(CensorCapability censorCapability, ulong interfaceId, ushort methodId, SerializerState inArgs) + { + _censorCapability = censorCapability; + _promisedAnswer = new PromisedAnswer(this); + + Bob = censorCapability.InterceptedCapability; + InterfaceId = interfaceId; + MethodId = methodId; + InArgs = inArgs; + State = InterceptionState.RequestedFromAlice; + } + + static void InterceptCaps(DeserializerState state, IInterceptionPolicy policy) + { + if (state.Caps != null) + { + for (int i = 0; i < state.Caps.Count; i++) + { + state.Caps[i] = policy.Attach(state.Caps[i]); + state.Caps[i].AddRef(); + } + } + } + + static void UninterceptCaps(DeserializerState state, IInterceptionPolicy policy) + { + if (state.Caps != null) + { + for (int i = 0; i < state.Caps.Count; i++) + { + state.Caps[i] = policy.Detach(state.Caps[i]); + state.Caps[i].AddRef(); + } + } + } + + public void InterceptInCaps(IInterceptionPolicy policyOverride = null) + { + InterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy); + } + + public void InterceptOutCaps(IInterceptionPolicy policyOverride = null) + { + InterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy); + } + + public void UninterceptInCaps(IInterceptionPolicy policyOverride = null) + { + UninterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy); + } + + public void UninterceptOutCaps(IInterceptionPolicy policyOverride = null) + { + UninterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy); + } + + public void ForwardToBob(CancellationToken cancellationToken = default) + { + if (Bob == null) + { + throw new InvalidOperationException("Bob is null"); + } + + var answer = BobProxy.Call(InterfaceId, MethodId, InArgs.Rewrap(), IsTailCall, cancellationToken); + + State = InterceptionState.ForwardedToBob; + + async void ChangeStateWhenReturned() + { + using (answer) + { + try + { + OutArgs = await answer.WhenReturned; + } + catch (TaskCanceledException) + { + IsCanceled = true; + } + catch (System.Exception exception) + { + Exception = exception.Message; + } + } + + State = InterceptionState.ReturnedFromBob; + + _censorCapability.Policy.OnReturnFromBob(this); + } + + ChangeStateWhenReturned(); + } + + public void ReturnToAlice() + { + try + { + _promisedAnswer.Return(); + } + catch (InvalidOperationException) + { + throw new InvalidOperationException("The call was already returned"); + } + + State = InterceptionState.ReturnedToAlice; + } + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/CensorCapability.cs b/Capnp.Net.Runtime/Rpc/Interception/CensorCapability.cs new file mode 100644 index 0000000..4fc0feb --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/CensorCapability.cs @@ -0,0 +1,44 @@ +namespace Capnp.Rpc.Interception +{ + class CensorCapability : RefCountingCapability + { + public CensorCapability(ConsumedCapability interceptedCapability, IInterceptionPolicy policy) + { + InterceptedCapability = interceptedCapability; + interceptedCapability.AddRef(); + Policy = policy; + MyVine = Vine.Create(this); + } + + public ConsumedCapability InterceptedCapability { get; } + public IInterceptionPolicy Policy { get; } + internal Skeleton MyVine { get; } + + protected override void ReleaseRemotely() + { + InterceptedCapability.Release(); + } + + internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall) + { + var cc = new CallContext(this, interfaceId, methodId, args); + Policy.OnCallFromAlice(cc); + return cc.Answer; + } + + internal override void Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer) + { + writer.which = CapDescriptor.WHICH.SenderHosted; + writer.SenderHosted = endpoint.AllocateExport(MyVine, out bool _); + } + + internal override void Freeze(out IRpcEndpoint boundEndpoint) + { + boundEndpoint = null; + } + + internal override void Unfreeze() + { + } + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/IInterceptionPolicy.cs b/Capnp.Net.Runtime/Rpc/Interception/IInterceptionPolicy.cs new file mode 100644 index 0000000..27bc284 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/IInterceptionPolicy.cs @@ -0,0 +1,10 @@ +using System; + +namespace Capnp.Rpc.Interception +{ + public interface IInterceptionPolicy: IEquatable + { + void OnCallFromAlice(CallContext callContext); + void OnReturnFromBob(CallContext callContext); + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/InterceptionState.cs b/Capnp.Net.Runtime/Rpc/Interception/InterceptionState.cs new file mode 100644 index 0000000..67eea0e --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/InterceptionState.cs @@ -0,0 +1,28 @@ +namespace Capnp.Rpc.Interception +{ + /// + /// The state of an intercepted call from Alice to Bob. + /// + public enum InterceptionState + { + /// + /// Alice initiated the call, but it was neither forwarded to Bob nor finished. + /// + RequestedFromAlice, + + /// + /// The call was forwarded to Bob. + /// + ForwardedToBob, + + /// + /// The call returned from Bob (to whom it was forwarded), but no result was yet forwarded to Alice. + /// + ReturnedFromBob, + + /// + /// The call was returned to Alice (either with results, exception, or cancelled) + /// + ReturnedToAlice + } +} diff --git a/Capnp.Net.Runtime/Rpc/Interception/Interceptor.cs b/Capnp.Net.Runtime/Rpc/Interception/Interceptor.cs new file mode 100644 index 0000000..7c01120 --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/Interception/Interceptor.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +namespace Capnp.Rpc.Interception +{ + public static class Interceptor + { + static readonly ConditionalWeakTable _interceptMap = + new ConditionalWeakTable(); + + public static TCap Attach(this IInterceptionPolicy policy, TCap cap) + where TCap: class + { + if (policy == null) + throw new ArgumentNullException(nameof(policy)); + + if (cap == null) + throw new ArgumentNullException(nameof(cap)); + + var cur = cap as CensorCapability; + + while (cur != null) + { + if (policy.Equals(cur.Policy)) + { + return cap; + } + + cur = cur.InterceptedCapability as CensorCapability; + } + + switch (cap) + { + case Proxy proxy: + return CapabilityReflection.CreateProxy(Attach(policy, proxy.ConsumedCap)) as TCap; + + case ConsumedCapability ccap: + return new CensorCapability(ccap, policy) as TCap; + + default: + return Attach(policy, + CapabilityReflection.CreateProxy( + LocalCapability.Create( + Skeleton.GetOrCreateSkeleton(cap, false))) as TCap); + } + } + + public static TCap Detach(this IInterceptionPolicy policy, TCap cap) + where TCap: class + { + if (policy == null) + throw new ArgumentNullException(nameof(policy)); + + if (cap == null) + throw new ArgumentNullException(nameof(cap)); + + switch (cap) + { + case Proxy proxy: + return CapabilityReflection.CreateProxy(Detach(policy, proxy.ConsumedCap)) as TCap; + + case CensorCapability ccap: + { + var cur = ccap; + var stk = new Stack(); + + do + { + if (policy.Equals(cur.Policy)) + { + var cur2 = cur.InterceptedCapability; + + foreach (var p in stk) + { + cur2 = p.Attach(cur2); + } + return cur2 as TCap; + } + + stk.Push(cur.Policy); + cur = cur.InterceptedCapability as CensorCapability; + } + while (cur != null); + + return ccap as TCap; + } + + default: + return cap; + } + } + } +} diff --git a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs index 2ef55af..83bdcd6 100644 --- a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs +++ b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs @@ -38,40 +38,6 @@ namespace Capnp.Rpc } } - //public Task WhenReady => ChainedAwaitWhenReady(); - - //public void Pipeline(PromisedAnswer.READER rd, Action action, Action error) - //{ - // lock (_reentrancyBlocker) - // { - // if (_chainedTask == null) - // { - // _chainedTask = InitialAwaitWhenReady(); - // } - - // _chainedTask = _chainedTask.ContinueWith(t => - // { - // bool rethrow = true; - - // try - // { - // t.Wait(); - // rethrow = false; - // EvaluateProxyAndCallContinuation(rd, action); - // } - // catch (AggregateException aggregateException) - // { - // var innerException = aggregateException.InnerException; - - // error(innerException); - - // if (rethrow) throw innerException; - // } - // }, - // TaskContinuationOptions.ExecuteSynchronously); - // } - //} - async Task AwaitChainedTask(Task chainedTask, Func, Task> func) { try diff --git a/Capnp.Net.Runtime/Rpc/RpcEngine.cs b/Capnp.Net.Runtime/Rpc/RpcEngine.cs index d9e2f0a..316e048 100644 --- a/Capnp.Net.Runtime/Rpc/RpcEngine.cs +++ b/Capnp.Net.Runtime/Rpc/RpcEngine.cs @@ -1311,7 +1311,7 @@ namespace Capnp.Rpc } } - public IReadOnlyList ImportCapTable(Payload.READER payload) + public IList ImportCapTable(Payload.READER payload) { var list = new List(); diff --git a/Capnp.Net.Runtime/Rpc/Skeleton.cs b/Capnp.Net.Runtime/Rpc/Skeleton.cs index 2be39c9..856f2a1 100644 --- a/Capnp.Net.Runtime/Rpc/Skeleton.cs +++ b/Capnp.Net.Runtime/Rpc/Skeleton.cs @@ -39,6 +39,9 @@ namespace Capnp.Rpc internal static Skeleton GetOrCreateSkeleton(T impl, bool addRef) where T: class { + if (impl == null) + throw new ArgumentNullException(nameof(impl)); + if (impl is Skeleton skel) return skel;