Merge pull request #26 from c80k/interception

Interception
This commit is contained in:
c80k 2019-11-07 14:21:01 +01:00 committed by GitHub
commit 233d9b5e84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1237 additions and 165 deletions

View File

@ -9,11 +9,11 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<Compile Include="..\Capnp.Net.Runtime.Tests\CallContext.cs" Link="CallContext.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\DeserializationTests.cs" Link="DeserializationTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\DeserializationTests.cs" Link="DeserializationTests.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\DynamicSerializerStateTests.cs" Link="DynamicSerializerStateTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\DynamicSerializerStateTests.cs" Link="DynamicSerializerStateTests.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\FramePumpTests.cs" Link="FramePumpTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\FramePumpTests.cs" Link="FramePumpTests.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\General.cs" Link="General.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\General.cs" Link="General.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Interception.cs" Link="Interception.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Issue19.cs" Link="Issue19.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\Issue19.cs" Link="Issue19.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Issue20.cs" Link="Issue20.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\Issue20.cs" Link="Issue20.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Issue25.cs" Link="Issue25.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\Issue25.cs" Link="Issue25.cs" />
@ -30,6 +30,7 @@
<Compile Include="..\Capnp.Net.Runtime.Tests\TcpRpcStress.cs" Link="TcpRpcStress.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TcpRpcStress.cs" Link="TcpRpcStress.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\test.cs" Link="test.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\test.cs" Link="test.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestBase.cs" Link="TestBase.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TestBase.cs" Link="TestBase.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestCallContext.cs" Link="TestCallContext.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestCapImplementations.cs" Link="TestCapImplementations.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TestCapImplementations.cs" Link="TestCapImplementations.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestInterfaces.cs" Link="TestInterfaces.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TestInterfaces.cs" Link="TestInterfaces.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\WirePointerTests.cs" Link="WirePointerTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\WirePointerTests.cs" Link="WirePointerTests.cs" />

View File

@ -0,0 +1,624 @@
using Capnp.Net.Runtime.Tests.GenImpls;
using Capnp.Rpc;
using Capnp.Rpc.Interception;
using Capnproto_test.Capnp.Test;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
namespace Capnp.Net.Runtime.Tests
{
[TestClass]
public class Interception: TestBase
{
class MyPolicy : IInterceptionPolicy
{
readonly string _id;
readonly BufferBlock<CallContext> _callSubject = new BufferBlock<CallContext>();
readonly BufferBlock<CallContext> _returnSubject = new BufferBlock<CallContext>();
public MyPolicy(string id)
{
_id = id;
}
public bool Equals(IInterceptionPolicy other)
{
return other is MyPolicy myPolicy && _id.Equals(myPolicy._id);
}
public override bool Equals(object obj)
{
return obj is IInterceptionPolicy other && Equals(other);
}
public override int GetHashCode()
{
return _id.GetHashCode();
}
public void OnCallFromAlice(CallContext callContext)
{
Assert.IsTrue(_callSubject.Post(callContext));
}
public void OnReturnFromBob(CallContext callContext)
{
Assert.IsTrue(_returnSubject.Post(callContext));
}
public IReceivableSourceBlock<CallContext> Calls => _callSubject;
public IReceivableSourceBlock<CallContext> Returns => _returnSubject;
}
[TestMethod]
public void InterceptServerSideObserveCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = policy.Attach<ITestInterface>(new TestInterfaceImpl(counters));
using (var main = client.GetMain<ITestInterface>())
{
var request1 = main.Foo(123, true, default);
var fcc = policy.Calls.ReceiveAsync();
Assert.IsTrue(fcc.Wait(MediumNonDbgTimeout));
var cc = fcc.Result;
var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs);
Assert.AreEqual(123u, pr.I);
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs);
Assert.AreEqual("foo", rr.X);
cc.ReturnToAlice();
Assert.IsTrue(request1.Wait(MediumNonDbgTimeout));
Assert.AreEqual("foo", request1.Result);
}
}
}
[TestMethod]
public void InterceptClientSideModifyCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.AreEqual(InterceptionState.RequestedFromAlice, cc.State);
var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs);
Assert.AreEqual(321u, pr.I);
Assert.AreEqual(false, pr.J);
var pw = cc.InArgs.Rewrap<Capnproto_test.Capnp.Test.TestInterface.Params_foo.WRITER>();
pw.I = 123u;
pw.J = true;
cc.ForwardToBob();
var rx = policy.Returns.ReceiveAsync();
// Racing against Bob's answer
Assert.IsTrue(cc.State == InterceptionState.ForwardedToBob || rx.IsCompleted);
Assert.IsTrue(rx.Wait(MediumNonDbgTimeout));
var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs);
Assert.AreEqual("foo", rr.X);
Assert.IsFalse(request1.IsCompleted);
var rw = ((DynamicSerializerState)cc.OutArgs).Rewrap<Capnproto_test.Capnp.Test.TestInterface.Result_foo.WRITER>();
rw.X = "bar";
cc.OutArgs = rw;
Assert.AreEqual(InterceptionState.ReturnedFromBob, cc.State);
cc.ReturnToAlice();
Assert.AreEqual(InterceptionState.ReturnedToAlice, cc.State);
Assert.IsTrue(request1.IsCompleted);
Assert.AreEqual("bar", request1.Result);
}
}
}
[TestMethod]
public void InterceptClientSideShortCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
var rw = SerializerState.CreateForRpc<Capnproto_test.Capnp.Test.TestInterface.Result_foo.WRITER>();
rw.X = "bar";
cc.OutArgs = rw;
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.AreEqual("bar", request1.Result);
}
}
}
[TestMethod]
public void InterceptClientSideRejectCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
cc.Exception = "rejected";
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.IsTrue(request1.IsFaulted);
Assert.AreEqual("rejected", request1.Exception.InnerException.Message);
}
}
}
[TestMethod]
public void InterceptClientSideCancelReturn()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
Assert.IsFalse(cc.CancelFromAlice.IsCancellationRequested);
cc.ReturnCanceled = true;
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.IsTrue(request1.IsCanceled);
}
}
}
[TestMethod]
public void InterceptClientSideOverrideCanceledCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, new CancellationToken(true));
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
Assert.IsTrue(cc.CancelFromAlice.IsCancellationRequested);
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
Assert.IsTrue(cc.ReturnCanceled);
cc.ReturnCanceled = false;
cc.Exception = "Cancelled";
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.IsTrue(request1.IsFaulted);
}
}
}
[TestMethod]
public void InterceptClientSideRedirectCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(123, true, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
var counters2 = new Counters();
var impl2 = new TestInterfaceImpl(counters2);
cc.Bob = impl2;
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.AreEqual("foo", request1.Result);
Assert.AreEqual(0, counters.CallCount);
Assert.AreEqual(1, counters2.CallCount);
}
}
}
[TestMethod]
public void InterfaceAndMethodId()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var baz = main.Baz(new TestAllTypes());
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsTrue(cc.MethodId == 2);
Assert.AreEqual(new TestInterface_Skeleton().InterfaceId, cc.InterfaceId);
}
}
}
[TestMethod]
public void TailCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestTailCallerImpl(counters);
using (var main = client.GetMain<ITestTailCaller>())
{
var calleeCallCount = new Counters();
var callee = policy.Attach<ITestTailCallee>(new TestTailCalleeImpl(calleeCallCount));
var promise = main.Foo(456, callee, default);
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
var cc = ccf.Result;
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
Assert.IsTrue(promise.Wait(MediumNonDbgTimeout));
Assert.AreEqual("from TestTailCaller", promise.Result.T);
}
}
}
[TestMethod]
public void InterceptInCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestMoreStuffImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestMoreStuff>()))
{
var counters2 = new Counters();
var cap = new TestInterfaceImpl(counters2);
var promise = main.CallFoo(cap);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.InterceptInCaps();
cc.ForwardToBob();
var cc2f = policy.Calls.ReceiveAsync();
Assert.IsTrue(cc2f.Wait(MediumNonDbgTimeout));
var cc2 = cc2f.Result;
cc2.ForwardToBob();
var cc2fr = policy.Returns.ReceiveAsync();
Assert.IsTrue(cc2fr.Wait(MediumNonDbgTimeout));
Assert.AreSame(cc2, cc2fr.Result);
Assert.AreEqual(1, counters2.CallCount);
cc2.ReturnToAlice();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
Assert.AreSame(cc, ccfr.Result);
cc.ReturnToAlice();
Assert.IsTrue(promise.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void InterceptOutCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = policy.Attach<ITestMoreStuff>(new TestMoreStuffImpl(counters));
using (var main = client.GetMain<ITestMoreStuff>())
{
var counters2 = new Counters();
var cap = new TestInterfaceImpl(counters2);
main.Hold(cap);
{
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.ReturnToAlice();
}
var ghf = main.GetHeld();
{
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.InterceptOutCaps();
ccf.Result.ReturnToAlice();
}
Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout));
var held = ghf.Result;
var foof = held.Foo(123, true);
{
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.ReturnToAlice();
}
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void UninterceptOutCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestMoreStuffImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestMoreStuff>()))
{
var counters2 = new Counters();
var cap = new TestInterfaceImpl(counters2);
main.Hold(cap);
{
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.InterceptInCaps();
cc.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
}
main.CallHeld();
{
// CallHeld
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.ForwardToBob();
// actual call on held cap.
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.ReturnToAlice();
ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
}
var ghf = main.GetHeld();
{
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.InterceptInCaps();
cc.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
cc.UninterceptOutCaps();
cc.ReturnToAlice();
}
Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout));
var held = ghf.Result;
var foof = held.Foo(123, true);
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void UninterceptInCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestMoreStuffImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestMoreStuff>()))
{
var counters2 = new Counters();
var cap = policy.Attach<ITestInterface>(new TestInterfaceImpl(counters2));
var foof = main.CallFoo(cap);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.UninterceptInCaps();
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void MultiAttachAndDetach()
{
var a = new MyPolicy("a");
var b = new MyPolicy("b");
var c = new MyPolicy("c");
var counters = new Counters();
var impl = new TestInterfaceImpl(counters);
var implA = a.Attach<ITestInterface>(impl);
var implAbc = b.Attach(a.Attach(b.Attach(c.Attach(implA))));
var implAc = b.Detach(implAbc);
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
server.Main = implAc;
using (var main = client.GetMain<ITestInterface>())
{
var foof = main.Foo(123, true);
var ccf1 = c.Calls.ReceiveAsync();
Assert.IsTrue(ccf1.Wait(MediumNonDbgTimeout));
var cc1 = ccf1.Result;
cc1.ForwardToBob();
var ccf2 = a.Calls.ReceiveAsync();
Assert.IsTrue(ccf2.Wait(MediumNonDbgTimeout));
var cc2 = ccf2.Result;
cc2.ForwardToBob();
Assert.IsTrue(a.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc2.ReturnToAlice();
Assert.IsTrue(c.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc1.ReturnToAlice();
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
}
}

View File

@ -8,16 +8,16 @@ namespace Capnp.Net.Runtime.Tests
{ {
class ProvidedCapabilityMultiCallMock : Skeleton class ProvidedCapabilityMultiCallMock : Skeleton
{ {
readonly BufferBlock<CallContext> _ccs = new BufferBlock<CallContext>(); readonly BufferBlock<TestCallContext> _ccs = new BufferBlock<TestCallContext>();
public override Task<AnswerOrCounterquestion> Invoke(ulong interfaceId, ushort methodId, public override Task<AnswerOrCounterquestion> Invoke(ulong interfaceId, ushort methodId,
DeserializerState args, CancellationToken cancellationToken = default(CancellationToken)) DeserializerState args, CancellationToken cancellationToken = default(CancellationToken))
{ {
var cc = new CallContext(interfaceId, methodId, args, cancellationToken); var cc = new TestCallContext(interfaceId, methodId, args, cancellationToken);
Assert.IsTrue(_ccs.Post(cc)); Assert.IsTrue(_ccs.Post(cc));
return cc.Result.Task; return cc.Result.Task;
} }
public Task<CallContext> WhenCalled => _ccs.ReceiveAsync(); public Task<TestCallContext> WhenCalled => _ccs.ReceiveAsync();
} }
} }

View File

@ -141,7 +141,7 @@ namespace Capnp.Net.Runtime.Tests
var args = DynamicSerializerState.CreateForRpc(); var args = DynamicSerializerState.CreateForRpc();
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
@ -182,7 +182,7 @@ namespace Capnp.Net.Runtime.Tests
var args = DynamicSerializerState.CreateForRpc(); var args = DynamicSerializerState.CreateForRpc();
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
@ -223,7 +223,7 @@ namespace Capnp.Net.Runtime.Tests
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
CancellationToken ctx; CancellationToken ctx;
using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
@ -267,7 +267,7 @@ namespace Capnp.Net.Runtime.Tests
args.WriteData(0, 123456); args.WriteData(0, 123456);
CancellationToken ctx; CancellationToken ctx;
IPromisedAnswer answer; IPromisedAnswer answer;
using (answer = main.Call(0x1234567812345678, 0x3333, args, false)) using (answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
@ -324,7 +324,7 @@ namespace Capnp.Net.Runtime.Tests
var args = DynamicSerializerState.CreateForRpc(); var args = DynamicSerializerState.CreateForRpc();
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
using (var answer = main.Call(0x1234567812345678, 0x3333, args, false)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
@ -361,7 +361,7 @@ namespace Capnp.Net.Runtime.Tests
var args = DynamicSerializerState.CreateForRpc(); var args = DynamicSerializerState.CreateForRpc();
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
@ -372,7 +372,7 @@ namespace Capnp.Net.Runtime.Tests
args2.SetStruct(1, 0); args2.SetStruct(1, 0);
args2.WriteData(0, 654321); args2.WriteData(0, 654321);
using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2, false)) using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2))
{ {
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
Assert.AreEqual<ulong>(0x1234567812345678, interfaceId); Assert.AreEqual<ulong>(0x1234567812345678, interfaceId);
@ -434,7 +434,7 @@ namespace Capnp.Net.Runtime.Tests
var args = DynamicSerializerState.CreateForRpc(); var args = DynamicSerializerState.CreateForRpc();
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
@ -464,7 +464,7 @@ namespace Capnp.Net.Runtime.Tests
args2.SetStruct(1, 0); args2.SetStruct(1, 0);
args2.WriteData(0, 654321); args2.WriteData(0, 654321);
using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2, false)) using (var answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2))
{ {
Assert.IsTrue(answer.WhenReturned.Wait(MediumTimeout)); Assert.IsTrue(answer.WhenReturned.Wait(MediumTimeout));
Assert.IsTrue(mock2.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock2.WhenCalled.Wait(MediumTimeout));
@ -510,7 +510,7 @@ namespace Capnp.Net.Runtime.Tests
var args = DynamicSerializerState.CreateForRpc(); var args = DynamicSerializerState.CreateForRpc();
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
@ -524,8 +524,8 @@ namespace Capnp.Net.Runtime.Tests
args3.SetStruct(1, 0); args3.SetStruct(1, 0);
args3.WriteData(0, 222222); args3.WriteData(0, 222222);
using (var answer2 = pipelined.Call(0x1111111111111111, 0x1111, args2, false)) using (var answer2 = pipelined.Call(0x1111111111111111, 0x1111, args2))
using (var answer3 = pipelined.Call(0x2222222222222222, 0x2222, args3, false)) using (var answer3 = pipelined.Call(0x2222222222222222, 0x2222, args3))
{ {
(var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result; (var interfaceId, var methodId, var inargs, var ct) = mock.WhenCalled.Result;
@ -555,8 +555,8 @@ namespace Capnp.Net.Runtime.Tests
args5.SetStruct(1, 0); args5.SetStruct(1, 0);
args5.WriteData(0, 444444); args5.WriteData(0, 444444);
using (var answer4 = pipelined.Call(0x3333333333333333, 0x3333, args4, false)) using (var answer4 = pipelined.Call(0x3333333333333333, 0x3333, args4))
using (var answer5 = pipelined.Call(0x4444444444444444, 0x4444, args5, false)) using (var answer5 = pipelined.Call(0x4444444444444444, 0x4444, args5))
{ {
var call2 = mock2.WhenCalled; var call2 = mock2.WhenCalled;
var call3 = mock2.WhenCalled; var call3 = mock2.WhenCalled;
@ -628,7 +628,7 @@ namespace Capnp.Net.Runtime.Tests
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
BareProxy pipelined; BareProxy pipelined;
using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
@ -643,7 +643,7 @@ namespace Capnp.Net.Runtime.Tests
try try
{ {
pipelined.Call(0x8765432187654321, 0x4444, args2, false); pipelined.Call(0x8765432187654321, 0x4444, args2);
Assert.Fail("Expected an exception here"); Assert.Fail("Expected an exception here");
} }
catch (ObjectDisposedException) catch (ObjectDisposedException)
@ -675,7 +675,7 @@ namespace Capnp.Net.Runtime.Tests
args.SetStruct(1, 0); args.SetStruct(1, 0);
args.WriteData(0, 123456); args.WriteData(0, 123456);
IPromisedAnswer answer2; IPromisedAnswer answer2;
using (var answer = main.Call(0x1234567812345678, 0x3333, args, true)) using (var answer = main.Call(0x1234567812345678, 0x3333, args))
{ {
Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout)); Assert.IsTrue(mock.WhenCalled.Wait(MediumTimeout));
@ -685,7 +685,7 @@ namespace Capnp.Net.Runtime.Tests
args2.SetStruct(1, 0); args2.SetStruct(1, 0);
args2.WriteData(0, 654321); args2.WriteData(0, 654321);
answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2, false); answer2 = pipelined.Call(0x8765432187654321, 0x4444, args2);
} }
using (answer2) using (answer2)

View File

@ -4,9 +4,9 @@ using Capnp.Rpc;
namespace Capnp.Net.Runtime.Tests namespace Capnp.Net.Runtime.Tests
{ {
class CallContext class TestCallContext
{ {
public CallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct) public TestCallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct)
{ {
InterfaceId = interfaceId; InterfaceId = interfaceId;
MethodId = methodId; MethodId = methodId;

View File

@ -47,7 +47,7 @@ namespace Capnp
/// <summary> /// <summary>
/// The capabilities imported from the capability table. Only valid in RPC context. /// The capabilities imported from the capability table. Only valid in RPC context.
/// </summary> /// </summary>
public IReadOnlyList<Rpc.ConsumedCapability> Caps { get; set; } public IList<Rpc.ConsumedCapability> Caps { get; set; }
/// <summary> /// <summary>
/// Current segment (essentially Segments[CurrentSegmentIndex] /// Current segment (essentially Segments[CurrentSegmentIndex]
/// </summary> /// </summary>

View File

@ -25,6 +25,12 @@ namespace Capnp
if (to == null) if (to == null)
throw new ArgumentNullException(nameof(to)); throw new ArgumentNullException(nameof(to));
if (from.Caps != null && to.Caps != null)
{
to.Caps.Clear();
to.Caps.AddRange(from.Caps);
}
var ds = to.Rewrap<DynamicSerializerState>(); var ds = to.Rewrap<DynamicSerializerState>();
IReadOnlyList<DeserializerState> items; IReadOnlyList<DeserializerState> items;

View File

@ -45,9 +45,9 @@
/// <param name="args">Method arguments</param> /// <param name="args">Method arguments</param>
/// <param name="tailCall">Whether it is a tail call</param> /// <param name="tailCall">Whether it is a tail call</param>
/// <returns>Answer promise</returns> /// <returns>Answer promise</returns>
public IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall) public IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
return base.Call(interfaceId, methodId, args, tailCall); return base.Call(interfaceId, methodId, args, default);
} }
} }
} }

View File

@ -10,7 +10,7 @@ namespace Capnp.Rpc
/// </summary> /// </summary>
public abstract class ConsumedCapability public abstract class ConsumedCapability
{ {
internal abstract IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall); internal abstract IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args);
/// <summary> /// <summary>
/// Request the RPC engine to release this capability from its import table, /// Request the RPC engine to release this capability from its import table,

View File

@ -0,0 +1,339 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Capnp.Rpc.Interception
{
/// <summary>
/// Context of an intercepted call. Provides access to parameters and results,
/// and the possibility to redirect the call to some other capability.
/// </summary>
public class CallContext
{
class PromisedAnswer : IPromisedAnswer
{
readonly CallContext _callContext;
readonly TaskCompletionSource<DeserializerState> _futureResult = new TaskCompletionSource<DeserializerState>();
readonly CancellationTokenSource _cancelFromAlice = new CancellationTokenSource();
public PromisedAnswer(CallContext callContext)
{
_callContext = callContext;
}
public Task<DeserializerState> WhenReturned => _futureResult.Task;
public CancellationToken CancelFromAlice => _cancelFromAlice.Token;
async Task<Proxy> AccessWhenReturned(MemberAccessPath access)
{
await WhenReturned;
return new Proxy(Access(access));
}
public ConsumedCapability Access(MemberAccessPath access)
{
if (_futureResult.Task.IsCompleted)
{
try
{
return access.Eval(WhenReturned.Result);
}
catch (AggregateException exception)
{
throw exception.InnerException;
}
}
else
{
return new LazyCapability(AccessWhenReturned(access));
}
}
public void Dispose()
{
try
{
_cancelFromAlice.Cancel();
}
catch (ObjectDisposedException)
{
// May happen when cancellation request from Alice arrives after return.
}
}
public void Return()
{
try
{
if (_callContext.ReturnCanceled)
{
_futureResult.SetCanceled();
}
else if (_callContext.Exception != null)
{
_futureResult.SetException(new RpcException(_callContext.Exception));
}
else
{
_futureResult.SetResult(_callContext.OutArgs);
}
}
finally
{
_cancelFromAlice.Dispose();
}
}
}
/// <summary>
/// Target interface ID of this call
/// </summary>
public ulong InterfaceId { get; }
/// <summary>
/// Target method ID of this call
/// </summary>
public ushort MethodId { get; }
/// <summary>
/// Lifecycle state of this call
/// </summary>
public InterceptionState State { get; private set; }
/// <summary>
/// Input arguments
/// </summary>
public SerializerState InArgs { get; set; }
/// <summary>
/// Output arguments ("return value")
/// </summary>
public DeserializerState OutArgs { get; set; }
/// <summary>
/// Exception text, or null if there is no exception
/// </summary>
public string Exception { get; set; }
/// <summary>
/// Whether the call should return in canceled state to Alice (the original caller).
/// In case of forwarding (<see cref="ForwardToBob()"/>) the property is automatically set according
/// to the cancellation state of Bob's answer. However, you may override it:
/// <list type="bullet">
/// <item><description>Setting it from 'false' to 'true' means that we pretend Alice a canceled call.
/// If Alice never requested cancellation this will surprise her pretty much.</description></item>
/// <item><description>Setting it from 'true' to 'false' overrides an existing cancellation. Since
/// we did not receive any output arguments from Bob (due to the cancellation), you *must* provide
/// either <see cref="OutArgs"/> or <see cref="Exception"/>.</description></item>
/// </list>
/// </summary>
public bool ReturnCanceled { get; set; }
/// <summary>
/// The cancellation token *from Alice* tells us when the original caller resigns from the call.
/// </summary>
public CancellationToken CancelFromAlice { get; private set; }
/// <summary>
/// The cancellation token *to Bob* tells the target capability when we resign from the forwarded call.
/// It is initialized with <seealso cref="CancelFromAlice"/>. Override it to achieve different behaviors:
/// E.g. set it to <code>CancellationToken.None</code> for "hiding" any cancellation request from Alice.
/// Set it to <code>new CancellationToken(true)</code> to pretend Bob a cancellation request.
/// </summary>
public CancellationToken CancelToBob { get; set; }
/// <summary>
/// Target capability. May be one of the following:
/// <list type="bullet">
/// <item><description>Capability interface implementation</description></item>
/// <item><description>A <see cref="Proxy"/>-derived object</description></item>
/// <item><description>A <see cref="Skeleton"/>-derived object</description></item>
/// <item><description>A <see cref="ConsumedCapability"/>-derived object (low level capability)</description></item>
/// <item><description>null</description></item>
/// </list>
/// </summary>
public object Bob
{
get => _bob;
set
{
if (value != _bob)
{
BobProxy?.Dispose();
BobProxy = null;
_bob = value;
switch (value)
{
case Proxy proxy:
BobProxy = proxy;
break;
case Skeleton skeleton:
BobProxy = CapabilityReflection.CreateProxy<object>(
LocalCapability.Create(skeleton));
break;
case ConsumedCapability cap:
BobProxy = CapabilityReflection.CreateProxy<object>(cap);
break;
case null:
break;
default:
BobProxy = CapabilityReflection.CreateProxy<object>(
LocalCapability.Create(
Skeleton.GetOrCreateSkeleton(value, false)));
break;
}
}
}
}
internal Proxy BobProxy { get; private set; }
readonly CensorCapability _censorCapability;
PromisedAnswer _promisedAnswer;
object _bob;
internal IPromisedAnswer Answer => _promisedAnswer;
internal CallContext(CensorCapability censorCapability, ulong interfaceId, ushort methodId, SerializerState inArgs)
{
_censorCapability = censorCapability;
_promisedAnswer = new PromisedAnswer(this);
CancelFromAlice = _promisedAnswer.CancelFromAlice;
CancelToBob = CancelFromAlice;
Bob = censorCapability.InterceptedCapability;
InterfaceId = interfaceId;
MethodId = methodId;
InArgs = inArgs;
State = InterceptionState.RequestedFromAlice;
}
static void InterceptCaps(DeserializerState state, IInterceptionPolicy policy)
{
if (state.Caps != null)
{
for (int i = 0; i < state.Caps.Count; i++)
{
state.Caps[i] = policy.Attach(state.Caps[i]);
state.Caps[i].AddRef();
}
}
}
static void UninterceptCaps(DeserializerState state, IInterceptionPolicy policy)
{
if (state.Caps != null)
{
for (int i = 0; i < state.Caps.Count; i++)
{
state.Caps[i] = policy.Detach(state.Caps[i]);
state.Caps[i].AddRef();
}
}
}
/// <summary>
/// Intercepts all capabilies inside the input arguments
/// </summary>
/// <param name="policyOverride">Policy to use, or null to further use present policy</param>
public void InterceptInCaps(IInterceptionPolicy policyOverride = null)
{
InterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy);
}
/// <summary>
/// Intercepts all capabilies inside the output arguments
/// </summary>
/// <param name="policyOverride">Policy to use, or null to further use present policy</param>
public void InterceptOutCaps(IInterceptionPolicy policyOverride = null)
{
InterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy);
}
/// <summary>
/// Unintercepts all capabilies inside the input arguments
/// </summary>
/// <param name="policyOverride">Policy to remove, or null to remove present policy</param>
public void UninterceptInCaps(IInterceptionPolicy policyOverride = null)
{
UninterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy);
}
/// <summary>
/// Unintercepts all capabilies inside the output arguments
/// </summary>
/// <param name="policyOverride">Policy to remove, or null to remove present policy</param>
public void UninterceptOutCaps(IInterceptionPolicy policyOverride = null)
{
UninterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy);
}
/// <summary>
/// Forwards this intercepted call to the target capability ("Bob").
/// </summary>
/// <param name="cancellationToken">Optional cancellation token, requesting Bob to cancel the call</param>
public void ForwardToBob()
{
if (Bob == null)
{
throw new InvalidOperationException("Bob is null");
}
var answer = BobProxy.Call(InterfaceId, MethodId, InArgs.Rewrap<DynamicSerializerState>(), default, CancelToBob);
State = InterceptionState.ForwardedToBob;
async void ChangeStateWhenReturned()
{
using (answer)
{
try
{
OutArgs = await answer.WhenReturned;
}
catch (TaskCanceledException)
{
ReturnCanceled = true;
}
catch (System.Exception exception)
{
Exception = exception.Message;
}
}
State = InterceptionState.ReturnedFromBob;
_censorCapability.Policy.OnReturnFromBob(this);
}
ChangeStateWhenReturned();
}
/// <summary>
/// Returns this intercepted call to the caller ("Alice").
/// </summary>
public void ReturnToAlice()
{
try
{
_promisedAnswer.Return();
}
catch (InvalidOperationException)
{
throw new InvalidOperationException("The call was already returned");
}
State = InterceptionState.ReturnedToAlice;
}
}
}

View File

@ -0,0 +1,44 @@
namespace Capnp.Rpc.Interception
{
class CensorCapability : RefCountingCapability
{
public CensorCapability(ConsumedCapability interceptedCapability, IInterceptionPolicy policy)
{
InterceptedCapability = interceptedCapability;
interceptedCapability.AddRef();
Policy = policy;
MyVine = Vine.Create(this);
}
public ConsumedCapability InterceptedCapability { get; }
public IInterceptionPolicy Policy { get; }
internal Skeleton MyVine { get; }
protected override void ReleaseRemotely()
{
InterceptedCapability.Release();
}
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{
var cc = new CallContext(this, interfaceId, methodId, args);
Policy.OnCallFromAlice(cc);
return cc.Answer;
}
internal override void Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer)
{
writer.which = CapDescriptor.WHICH.SenderHosted;
writer.SenderHosted = endpoint.AllocateExport(MyVine, out bool _);
}
internal override void Freeze(out IRpcEndpoint boundEndpoint)
{
boundEndpoint = null;
}
internal override void Unfreeze()
{
}
}
}

View File

@ -0,0 +1,23 @@
using System;
namespace Capnp.Rpc.Interception
{
/// <summary>
/// An interception policy implements callbacks for outgoing calls and returning forwarded calls.
/// </summary>
public interface IInterceptionPolicy: IEquatable<IInterceptionPolicy>
{
/// <summary>
/// A caller ("Alice") initiated a new call, which is now intercepted.
/// </summary>
/// <param name="callContext">Context object</param>
void OnCallFromAlice(CallContext callContext);
/// <summary>
/// Given that the intercepted call was forwarded, it returned now from the target ("Bob")
/// and may (or may not) be returned to the original caller ("Alice").
/// </summary>
/// <param name="callContext"></param>
void OnReturnFromBob(CallContext callContext);
}
}

View File

@ -0,0 +1,28 @@
namespace Capnp.Rpc.Interception
{
/// <summary>
/// The state of an intercepted call from Alice to Bob.
/// </summary>
public enum InterceptionState
{
/// <summary>
/// Alice initiated the call, but it was neither forwarded to Bob nor finished.
/// </summary>
RequestedFromAlice,
/// <summary>
/// The call was forwarded to Bob.
/// </summary>
ForwardedToBob,
/// <summary>
/// The call returned from Bob (to whom it was forwarded), but no result was yet forwarded to Alice.
/// </summary>
ReturnedFromBob,
/// <summary>
/// The call was returned to Alice (either with results, exception, or cancelled)
/// </summary>
ReturnedToAlice
}
}

View File

@ -0,0 +1,115 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
namespace Capnp.Rpc.Interception
{
/// <summary>
/// This static class provides extension methods for intercepting and unintercepting capabilities.
/// </summary>
public static class Interceptor
{
static readonly ConditionalWeakTable<ConsumedCapability, CensorCapability> _interceptMap =
new ConditionalWeakTable<ConsumedCapability, CensorCapability>();
/// <summary>
/// Attach this policy to given capability.
/// </summary>
/// <typeparam name="TCap">Capability interface type</typeparam>
/// <param name="policy">Policy to attach</param>
/// <param name="cap">Capability to censor</param>
/// <returns>Censored capability instance</returns>
/// <exception cref="ArgumentNullException"><paramref name="policy"/> is null or
/// <paramref name="cap"/> is null</exception>
public static TCap Attach<TCap>(this IInterceptionPolicy policy, TCap cap)
where TCap: class
{
if (policy == null)
throw new ArgumentNullException(nameof(policy));
if (cap == null)
throw new ArgumentNullException(nameof(cap));
var cur = cap as CensorCapability;
while (cur != null)
{
if (policy.Equals(cur.Policy))
{
return cap;
}
cur = cur.InterceptedCapability as CensorCapability;
}
switch (cap)
{
case Proxy proxy:
return CapabilityReflection.CreateProxy<TCap>(Attach(policy, proxy.ConsumedCap)) as TCap;
case ConsumedCapability ccap:
return new CensorCapability(ccap, policy) as TCap;
default:
return Attach(policy,
CapabilityReflection.CreateProxy<TCap>(
LocalCapability.Create(
Skeleton.GetOrCreateSkeleton(cap, false))) as TCap);
}
}
/// <summary>
/// Detach this policy from given (censored) capability.
/// </summary>
/// <typeparam name="TCap">Capability interface type</typeparam>
/// <param name="policy">Policy to detach</param>
/// <param name="cap">Capability to clean</param>
/// <returns>Clean capability instance (at least, without this interception policy)</returns>
/// <exception cref="ArgumentNullException"><paramref name="policy"/> is null or
/// <paramref name="cap"/> is null</exception>
public static TCap Detach<TCap>(this IInterceptionPolicy policy, TCap cap)
where TCap: class
{
if (policy == null)
throw new ArgumentNullException(nameof(policy));
if (cap == null)
throw new ArgumentNullException(nameof(cap));
switch (cap)
{
case Proxy proxy:
return CapabilityReflection.CreateProxy<TCap>(Detach(policy, proxy.ConsumedCap)) as TCap;
case CensorCapability ccap:
{
var cur = ccap;
var stk = new Stack<IInterceptionPolicy>();
do
{
if (policy.Equals(cur.Policy))
{
var cur2 = cur.InterceptedCapability;
foreach (var p in stk)
{
cur2 = p.Attach(cur2);
}
return cur2 as TCap;
}
stk.Push(cur.Policy);
cur = cur.InterceptedCapability as CensorCapability;
}
while (cur != null);
return ccap as TCap;
}
default:
return cap;
}
}
}
}

View File

@ -81,9 +81,7 @@ namespace Capnp.Rpc
public Task<Proxy> WhenResolved { get; } public Task<Proxy> WhenResolved { get; }
async Task<DeserializerState> CallImpl(ulong interfaceId, ushort methodId, async Task<DeserializerState> CallImpl(ulong interfaceId, ushort methodId, DynamicSerializerState args, CancellationToken cancellationToken)
DynamicSerializerState args, bool pipeline,
CancellationToken cancellationToken)
{ {
var cap = await WhenResolved; var cap = await WhenResolved;
@ -92,7 +90,7 @@ namespace Capnp.Rpc
if (cap == null) if (cap == null)
throw new RpcException("Broken capability"); throw new RpcException("Broken capability");
var call = cap.Call(interfaceId, methodId, args, pipeline); var call = cap.Call(interfaceId, methodId, args, default);
var whenReturned = call.WhenReturned; var whenReturned = call.WhenReturned;
using (var registration = cancellationToken.Register(call.Dispose)) using (var registration = cancellationToken.Register(call.Dispose))
@ -101,10 +99,10 @@ namespace Capnp.Rpc
} }
} }
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();
return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, pipeline, cts.Token)); return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, cts.Token));
} }
} }
} }

View File

@ -57,9 +57,7 @@ namespace Capnp.Rpc
} }
} }
async Task<DeserializerState> CallImpl(ulong interfaceId, ushort methodId, async Task<DeserializerState> CallImpl(ulong interfaceId, ushort methodId, DynamicSerializerState args, CancellationToken cancellationToken)
DynamicSerializerState args, bool pipeline,
CancellationToken cancellationToken)
{ {
var cap = await AwaitResolved(); var cap = await AwaitResolved();
@ -68,7 +66,7 @@ namespace Capnp.Rpc
if (cap == null) if (cap == null)
throw new RpcException("Broken capability"); throw new RpcException("Broken capability");
var call = cap.Call(interfaceId, methodId, args, pipeline); var call = cap.Call(interfaceId, methodId, args, default);
var whenReturned = call.WhenReturned; var whenReturned = call.WhenReturned;
using (var registration = cancellationToken.Register(() => call.Dispose())) using (var registration = cancellationToken.Register(() => call.Dispose()))
@ -77,10 +75,10 @@ namespace Capnp.Rpc
} }
} }
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();
return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, pipeline, cts.Token)); return new LocalAnswer(cts, CallImpl(interfaceId, methodId, args, cts.Token));
} }
protected override void ReleaseRemotely() protected override void ReleaseRemotely()

View File

@ -42,7 +42,7 @@ namespace Capnp.Rpc
ProvidedCap.Relinquish(); ProvidedCap.Relinquish();
} }
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();
var call = ProvidedCap.Invoke(interfaceId, methodId, args, cts.Token); var call = ProvidedCap.Invoke(interfaceId, methodId, args, cts.Token);

View File

@ -38,40 +38,6 @@ namespace Capnp.Rpc
} }
} }
//public Task<SerializerState> WhenReady => ChainedAwaitWhenReady();
//public void Pipeline(PromisedAnswer.READER rd, Action<Proxy> action, Action<System.Exception> error)
//{
// lock (_reentrancyBlocker)
// {
// if (_chainedTask == null)
// {
// _chainedTask = InitialAwaitWhenReady();
// }
// _chainedTask = _chainedTask.ContinueWith(t =>
// {
// bool rethrow = true;
// try
// {
// t.Wait();
// rethrow = false;
// EvaluateProxyAndCallContinuation(rd, action);
// }
// catch (AggregateException aggregateException)
// {
// var innerException = aggregateException.InnerException;
// error(innerException);
// if (rethrow) throw innerException;
// }
// },
// TaskContinuationOptions.ExecuteSynchronously);
// }
//}
async Task AwaitChainedTask(Task chainedTask, Func<Task<AnswerOrCounterquestion>, Task> func) async Task AwaitChainedTask(Task chainedTask, Func<Task<AnswerOrCounterquestion>, Task> func)
{ {
try try
@ -197,83 +163,8 @@ namespace Capnp.Rpc
}); });
} }
//Task<SerializerState> ChainedAwaitWhenReady()
//{
// async Task<SerializerState> AwaitChainedTask(Task chainedTask)
// {
// await chainedTask;
// return _callTask.Result;
// }
// Task<SerializerState> resultTask;
// lock (_reentrancyBlocker)
// {
// if (_chainedTask == null)
// {
// _chainedTask = InitialAwaitWhenReady();
// }
// resultTask = AwaitChainedTask(_chainedTask);
// _chainedTask = resultTask;
// }
// return resultTask;
//}
public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None; public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None;
//void EvaluateProxyAndCallContinuation(PromisedAnswer.READER rd, Action<Proxy> action)
//{
// var result = _callTask.Result;
// DeserializerState cur = result;
// foreach (var op in rd.Transform)
// {
// switch (op.which)
// {
// case PromisedAnswer.Op.WHICH.GetPointerField:
// try
// {
// cur = cur.StructReadPointer(op.GetPointerField);
// }
// catch (System.Exception)
// {
// throw new ArgumentOutOfRangeException("Illegal pointer field in transformation operation");
// }
// break;
// case PromisedAnswer.Op.WHICH.Noop:
// break;
// default:
// throw new ArgumentOutOfRangeException("Unknown transformation operation");
// }
// }
// Proxy proxy;
// switch (cur.Kind)
// {
// case ObjectKind.Capability:
// try
// {
// var cap = result.MsgBuilder.Caps[(int)cur.CapabilityIndex];
// proxy = new Proxy(cap ?? LazyCapability.Null);
// }
// catch (ArgumentOutOfRangeException)
// {
// throw new ArgumentOutOfRangeException("Bad capability table in internal answer - internal error?");
// }
// action(proxy);
// break;
// default:
// throw new ArgumentOutOfRangeException("Transformation did not result in a capability");
// }
//}
public async void Dispose() public async void Dispose()
{ {
if (_cts != null) if (_cts != null)

View File

@ -169,13 +169,13 @@ namespace Capnp.Rpc
wr.ImportedCap = _remoteId; wr.ImportedCap = _remoteId;
} }
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
lock (_reentrancyBlocker) lock (_reentrancyBlocker)
{ {
if (_resolvedCap.Task.IsCompleted) if (_resolvedCap.Task.IsCompleted)
{ {
return CallOnResolution(interfaceId, methodId, args, pipeline); return CallOnResolution(interfaceId, methodId, args);
} }
else else
{ {
@ -184,7 +184,7 @@ namespace Capnp.Rpc
} }
} }
var promisedAnswer = base.DoCall(interfaceId, methodId, args, pipeline); var promisedAnswer = base.DoCall(interfaceId, methodId, args);
TrackCall(promisedAnswer.WhenReturned); TrackCall(promisedAnswer.WhenReturned);
return promisedAnswer; return promisedAnswer;
} }

View File

@ -66,13 +66,15 @@ namespace Capnp.Rpc
/// <param name="interfaceId">Interface ID to call</param> /// <param name="interfaceId">Interface ID to call</param>
/// <param name="methodId">Method ID to call</param> /// <param name="methodId">Method ID to call</param>
/// <param name="args">Method arguments ("param struct")</param> /// <param name="args">Method arguments ("param struct")</param>
/// <param name="tailCall">Whether it is a tail call</param> /// <param name="obsoleteAndIgnored">This flag is ignored. It is there to preserve compatibility with the
/// code generator and will be removed in future versions.</param>
/// <param name="cancellationToken">For cancelling an ongoing method call</param> /// <param name="cancellationToken">For cancelling an ongoing method call</param>
/// <returns>An answer promise</returns> /// <returns>An answer promise</returns>
/// <exception cref="ObjectDisposedException">This instance was disposed, or transport-layer stream was disposed.</exception> /// <exception cref="ObjectDisposedException">This instance was disposed, or transport-layer stream was disposed.</exception>
/// <exception cref="InvalidOperationException">Capability is broken.</exception> /// <exception cref="InvalidOperationException">Capability is broken.</exception>
/// <exception cref="System.IO.IOException">An I/O error occurs.</exception> /// <exception cref="System.IO.IOException">An I/O error occurs.</exception>
protected internal IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall, CancellationToken cancellationToken = default) protected internal IPromisedAnswer Call(ulong interfaceId, ushort methodId, DynamicSerializerState args,
bool obsoleteAndIgnored, CancellationToken cancellationToken = default)
{ {
if (_disposedValue) if (_disposedValue)
throw new ObjectDisposedException(nameof(Proxy)); throw new ObjectDisposedException(nameof(Proxy));
@ -80,7 +82,7 @@ namespace Capnp.Rpc
if (ConsumedCap == null) if (ConsumedCap == null)
throw new InvalidOperationException("Cannot call null capability"); throw new InvalidOperationException("Cannot call null capability");
var answer = ConsumedCap.DoCall(interfaceId, methodId, args, tailCall); var answer = ConsumedCap.DoCall(interfaceId, methodId, args);
if (cancellationToken.CanBeCanceled) if (cancellationToken.CanBeCanceled)
{ {

View File

@ -99,7 +99,7 @@ namespace Capnp.Rpc
_access.Serialize(wr.PromisedAnswer); _access.Serialize(wr.PromisedAnswer);
} }
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
lock (_question.ReentrancyBlocker) lock (_question.ReentrancyBlocker)
{ {
@ -111,7 +111,7 @@ namespace Capnp.Rpc
throw new RpcException("Answer did not resolve to expected capability"); throw new RpcException("Answer did not resolve to expected capability");
} }
return CallOnResolution(interfaceId, methodId, args, pipeline); return CallOnResolution(interfaceId, methodId, args);
} }
else else
{ {
@ -130,7 +130,7 @@ namespace Capnp.Rpc
_question.DisallowFinish(); _question.DisallowFinish();
++_pendingCallsOnPromise; ++_pendingCallsOnPromise;
var promisedAnswer = base.DoCall(interfaceId, methodId, args, pipeline); var promisedAnswer = base.DoCall(interfaceId, methodId, args);
ReAllowFinishWhenDone(promisedAnswer.WhenReturned); ReAllowFinishWhenDone(promisedAnswer.WhenReturned);
async void DecrementPendingCallsOnPromiseWhenReturned() async void DecrementPendingCallsOnPromiseWhenReturned()

View File

@ -15,7 +15,7 @@ namespace Capnp.Rpc
_ep = ep; _ep = ep;
} }
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall) internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
var call = SetupMessage(args, interfaceId, methodId); var call = SetupMessage(args, interfaceId, methodId);
Debug.Assert(call.Target.which != MessageTarget.WHICH.undefined); Debug.Assert(call.Target.which != MessageTarget.WHICH.undefined);

View File

@ -29,7 +29,7 @@ namespace Capnp.Rpc
protected abstract void GetMessageTarget(MessageTarget.WRITER wr); protected abstract void GetMessageTarget(MessageTarget.WRITER wr);
protected IPromisedAnswer CallOnResolution(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool pipeline) protected IPromisedAnswer CallOnResolution(ulong interfaceId, ushort methodId, DynamicSerializerState args)
{ {
try try
{ {
@ -62,7 +62,7 @@ namespace Capnp.Rpc
#if DebugEmbargos #if DebugEmbargos
Logger.LogDebug("Direct call"); Logger.LogDebug("Direct call");
#endif #endif
return ResolvedCap.Call(interfaceId, methodId, args, pipeline); return ResolvedCap.Call(interfaceId, methodId, args, default);
} }
else else
{ {
@ -90,7 +90,7 @@ namespace Capnp.Rpc
cancellationTokenSource.Token.ThrowIfCancellationRequested(); cancellationTokenSource.Token.ThrowIfCancellationRequested();
return ResolvedCap.Call(interfaceId, methodId, args, pipeline); return ResolvedCap.Call(interfaceId, methodId, args, default);
}, TaskContinuationOptions.ExecuteSynchronously); }, TaskContinuationOptions.ExecuteSynchronously);

View File

@ -1311,7 +1311,7 @@ namespace Capnp.Rpc
} }
} }
public IReadOnlyList<ConsumedCapability> ImportCapTable(Payload.READER payload) public IList<ConsumedCapability> ImportCapTable(Payload.READER payload)
{ {
var list = new List<ConsumedCapability>(); var list = new List<ConsumedCapability>();

View File

@ -39,6 +39,9 @@ namespace Capnp.Rpc
internal static Skeleton GetOrCreateSkeleton<T>(T impl, bool addRef) internal static Skeleton GetOrCreateSkeleton<T>(T impl, bool addRef)
where T: class where T: class
{ {
if (impl == null)
throw new ArgumentNullException(nameof(impl));
if (impl is Skeleton skel) if (impl is Skeleton skel)
return skel; return skel;

View File

@ -32,7 +32,7 @@ namespace Capnp.Rpc
ulong interfaceId, ushort methodId, DeserializerState args, ulong interfaceId, ushort methodId, DeserializerState args,
CancellationToken cancellationToken = default) CancellationToken cancellationToken = default)
{ {
var promisedAnswer = Proxy.Call(interfaceId, methodId, (DynamicSerializerState)args, false); var promisedAnswer = Proxy.Call(interfaceId, methodId, (DynamicSerializerState)args, default);
if (promisedAnswer is PendingQuestion pendingQuestion && pendingQuestion.RpcEndpoint == Impatient.AskingEndpoint) if (promisedAnswer is PendingQuestion pendingQuestion && pendingQuestion.RpcEndpoint == Impatient.AskingEndpoint)
{ {