Interception base implementation

This commit is contained in:
Christian Köllner 2019-11-06 14:16:20 +01:00
parent 9e63e194bb
commit fd55167d39
14 changed files with 1019 additions and 42 deletions

View File

@ -9,11 +9,11 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<Compile Include="..\Capnp.Net.Runtime.Tests\CallContext.cs" Link="CallContext.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\DeserializationTests.cs" Link="DeserializationTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\DeserializationTests.cs" Link="DeserializationTests.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\DynamicSerializerStateTests.cs" Link="DynamicSerializerStateTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\DynamicSerializerStateTests.cs" Link="DynamicSerializerStateTests.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\FramePumpTests.cs" Link="FramePumpTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\FramePumpTests.cs" Link="FramePumpTests.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\General.cs" Link="General.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\General.cs" Link="General.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Interception.cs" Link="Interception.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Issue19.cs" Link="Issue19.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\Issue19.cs" Link="Issue19.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Issue20.cs" Link="Issue20.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\Issue20.cs" Link="Issue20.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\Issue25.cs" Link="Issue25.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\Issue25.cs" Link="Issue25.cs" />
@ -30,6 +30,7 @@
<Compile Include="..\Capnp.Net.Runtime.Tests\TcpRpcStress.cs" Link="TcpRpcStress.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TcpRpcStress.cs" Link="TcpRpcStress.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\test.cs" Link="test.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\test.cs" Link="test.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestBase.cs" Link="TestBase.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TestBase.cs" Link="TestBase.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestCallContext.cs" Link="TestCallContext.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestCapImplementations.cs" Link="TestCapImplementations.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TestCapImplementations.cs" Link="TestCapImplementations.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\TestInterfaces.cs" Link="TestInterfaces.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\TestInterfaces.cs" Link="TestInterfaces.cs" />
<Compile Include="..\Capnp.Net.Runtime.Tests\WirePointerTests.cs" Link="WirePointerTests.cs" /> <Compile Include="..\Capnp.Net.Runtime.Tests\WirePointerTests.cs" Link="WirePointerTests.cs" />

View File

@ -0,0 +1,587 @@
using Capnp.Net.Runtime.Tests.GenImpls;
using Capnp.Rpc;
using Capnp.Rpc.Interception;
using Capnproto_test.Capnp.Test;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
namespace Capnp.Net.Runtime.Tests
{
[TestClass]
public class Interception: TestBase
{
class MyPolicy : IInterceptionPolicy
{
readonly string _id;
readonly BufferBlock<CallContext> _callSubject = new BufferBlock<CallContext>();
readonly BufferBlock<CallContext> _returnSubject = new BufferBlock<CallContext>();
public MyPolicy(string id)
{
_id = id;
}
public bool Equals(IInterceptionPolicy other)
{
return other is MyPolicy myPolicy && _id.Equals(myPolicy._id);
}
public override bool Equals(object obj)
{
return obj is IInterceptionPolicy other && Equals(other);
}
public override int GetHashCode()
{
return _id.GetHashCode();
}
public void OnCallFromAlice(CallContext callContext)
{
Assert.IsTrue(_callSubject.Post(callContext));
}
public void OnReturnFromBob(CallContext callContext)
{
Assert.IsTrue(_returnSubject.Post(callContext));
}
public IReceivableSourceBlock<CallContext> Calls => _callSubject;
public IReceivableSourceBlock<CallContext> Returns => _returnSubject;
}
[TestMethod]
public void InterceptServerSideObserveCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = policy.Attach<ITestInterface>(new TestInterfaceImpl(counters));
using (var main = client.GetMain<ITestInterface>())
{
var request1 = main.Foo(123, true, default);
var fcc = policy.Calls.ReceiveAsync();
Assert.IsTrue(fcc.Wait(MediumNonDbgTimeout));
var cc = fcc.Result;
var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs);
Assert.AreEqual(123u, pr.I);
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs);
Assert.AreEqual("foo", rr.X);
cc.ReturnToAlice();
Assert.IsTrue(request1.Wait(MediumNonDbgTimeout));
Assert.AreEqual("foo", request1.Result);
}
}
}
[TestMethod]
public void InterceptClientSideModifyCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.AreEqual(InterceptionState.RequestedFromAlice, cc.State);
var pr = new Capnproto_test.Capnp.Test.TestInterface.Params_foo.READER(cc.InArgs);
Assert.AreEqual(321u, pr.I);
Assert.AreEqual(false, pr.J);
var pw = cc.InArgs.Rewrap<Capnproto_test.Capnp.Test.TestInterface.Params_foo.WRITER>();
pw.I = 123u;
pw.J = true;
cc.ForwardToBob();
var rx = policy.Returns.ReceiveAsync();
// Racing against Bob's answer
Assert.IsTrue(cc.State == InterceptionState.ForwardedToBob || rx.IsCompleted);
Assert.IsTrue(rx.Wait(MediumNonDbgTimeout));
var rr = new Capnproto_test.Capnp.Test.TestInterface.Result_foo.READER(cc.OutArgs);
Assert.AreEqual("foo", rr.X);
Assert.IsFalse(request1.IsCompleted);
var rw = ((DynamicSerializerState)cc.OutArgs).Rewrap<Capnproto_test.Capnp.Test.TestInterface.Result_foo.WRITER>();
rw.X = "bar";
cc.OutArgs = rw;
Assert.AreEqual(InterceptionState.ReturnedFromBob, cc.State);
cc.ReturnToAlice();
Assert.AreEqual(InterceptionState.ReturnedToAlice, cc.State);
Assert.IsTrue(request1.IsCompleted);
Assert.AreEqual("bar", request1.Result);
}
}
}
[TestMethod]
public void InterceptClientSideShortCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
var rw = SerializerState.CreateForRpc<Capnproto_test.Capnp.Test.TestInterface.Result_foo.WRITER>();
rw.X = "bar";
cc.OutArgs = rw;
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.AreEqual("bar", request1.Result);
}
}
}
[TestMethod]
public void InterceptClientSideRejectCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
cc.Exception = "rejected";
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.IsTrue(request1.IsFaulted);
Assert.AreEqual("rejected", request1.Exception.InnerException.Message);
}
}
}
[TestMethod]
public void InterceptClientSideCancelCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(321, false, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
cc.IsCanceled = true;
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.IsTrue(request1.IsCanceled);
}
}
}
[TestMethod]
public void InterceptClientSideRedirectCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var request1 = main.Foo(123, true, default);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsFalse(request1.IsCompleted);
var counters2 = new Counters();
var impl2 = new TestInterfaceImpl(counters2);
cc.Bob = impl2;
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
Assert.IsTrue(request1.IsCompleted);
Assert.AreEqual("foo", request1.Result);
Assert.AreEqual(0, counters.CallCount);
Assert.AreEqual(1, counters2.CallCount);
}
}
}
[TestMethod]
public void InterfaceAndMethodId()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestInterfaceImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestInterface>()))
{
var baz = main.Baz(new TestAllTypes());
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
Assert.IsTrue(cc.MethodId == 2);
Assert.AreEqual(new TestInterface_Skeleton().InterfaceId, cc.InterfaceId);
}
}
}
[TestMethod]
public void TailCall()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestTailCallerImpl(counters);
using (var main = client.GetMain<ITestTailCaller>())
{
var calleeCallCount = new Counters();
var callee = policy.Attach<ITestTailCallee>(new TestTailCalleeImpl(calleeCallCount));
var promise = main.Foo(456, callee, default);
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
var cc = ccf.Result;
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
Assert.IsTrue(promise.Wait(MediumNonDbgTimeout));
Assert.AreEqual("from TestTailCaller", promise.Result.T);
}
}
}
[TestMethod]
public void InterceptInCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestMoreStuffImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestMoreStuff>()))
{
var counters2 = new Counters();
var cap = new TestInterfaceImpl(counters2);
var promise = main.CallFoo(cap);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.InterceptInCaps();
cc.ForwardToBob();
var cc2f = policy.Calls.ReceiveAsync();
Assert.IsTrue(cc2f.Wait(MediumNonDbgTimeout));
var cc2 = cc2f.Result;
cc2.ForwardToBob();
var cc2fr = policy.Returns.ReceiveAsync();
Assert.IsTrue(cc2fr.Wait(MediumNonDbgTimeout));
Assert.AreSame(cc2, cc2fr.Result);
Assert.AreEqual(1, counters2.CallCount);
cc2.ReturnToAlice();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
Assert.AreSame(cc, ccfr.Result);
cc.ReturnToAlice();
Assert.IsTrue(promise.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void InterceptOutCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = policy.Attach<ITestMoreStuff>(new TestMoreStuffImpl(counters));
using (var main = client.GetMain<ITestMoreStuff>())
{
var counters2 = new Counters();
var cap = new TestInterfaceImpl(counters2);
main.Hold(cap);
{
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.ReturnToAlice();
}
var ghf = main.GetHeld();
{
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.InterceptOutCaps();
ccf.Result.ReturnToAlice();
}
Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout));
var held = ghf.Result;
var foof = held.Foo(123, true);
{
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.ReturnToAlice();
}
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void UninterceptOutCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestMoreStuffImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestMoreStuff>()))
{
var counters2 = new Counters();
var cap = new TestInterfaceImpl(counters2);
main.Hold(cap);
{
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.InterceptInCaps();
cc.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
}
main.CallHeld();
{
// CallHeld
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.ForwardToBob();
// actual call on held cap.
var ccf = policy.Calls.ReceiveAsync();
Assert.IsTrue(ccf.Wait(MediumNonDbgTimeout));
ccf.Result.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
ccf.Result.ReturnToAlice();
ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
}
var ghf = main.GetHeld();
{
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.InterceptInCaps();
cc.ForwardToBob();
var ccfr = policy.Returns.ReceiveAsync();
Assert.IsTrue(ccfr.Wait(MediumNonDbgTimeout));
cc.UninterceptOutCaps();
cc.ReturnToAlice();
}
Assert.IsTrue(ghf.Wait(MediumNonDbgTimeout));
var held = ghf.Result;
var foof = held.Foo(123, true);
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void UninterceptInCaps()
{
var policy = new MyPolicy("a");
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
var counters = new Counters();
server.Main = new TestMoreStuffImpl(counters);
using (var main = policy.Attach(client.GetMain<ITestMoreStuff>()))
{
var counters2 = new Counters();
var cap = policy.Attach<ITestInterface>(new TestInterfaceImpl(counters2));
var foof = main.CallFoo(cap);
Assert.IsTrue(policy.Calls.TryReceive(out var cc));
cc.UninterceptInCaps();
cc.ForwardToBob();
Assert.IsTrue(policy.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc.ReturnToAlice();
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
[TestMethod]
public void MultiAttachAndDetach()
{
var a = new MyPolicy("a");
var b = new MyPolicy("b");
var c = new MyPolicy("c");
var counters = new Counters();
var impl = new TestInterfaceImpl(counters);
var implA = a.Attach<ITestInterface>(impl);
var implAbc = b.Attach(a.Attach(b.Attach(c.Attach(implA))));
var implAc = b.Detach(implAbc);
(var server, var client) = SetupClientServerPair();
using (server)
using (client)
{
client.WhenConnected.Wait();
server.Main = implAc;
using (var main = client.GetMain<ITestInterface>())
{
var foof = main.Foo(123, true);
var ccf1 = c.Calls.ReceiveAsync();
Assert.IsTrue(ccf1.Wait(MediumNonDbgTimeout));
var cc1 = ccf1.Result;
cc1.ForwardToBob();
var ccf2 = a.Calls.ReceiveAsync();
Assert.IsTrue(ccf2.Wait(MediumNonDbgTimeout));
var cc2 = ccf2.Result;
cc2.ForwardToBob();
Assert.IsTrue(a.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc2.ReturnToAlice();
Assert.IsTrue(c.Returns.ReceiveAsync().Wait(MediumNonDbgTimeout));
cc1.ReturnToAlice();
Assert.IsTrue(foof.Wait(MediumNonDbgTimeout));
}
}
}
}
}

View File

@ -8,16 +8,16 @@ namespace Capnp.Net.Runtime.Tests
{ {
class ProvidedCapabilityMultiCallMock : Skeleton class ProvidedCapabilityMultiCallMock : Skeleton
{ {
readonly BufferBlock<CallContext> _ccs = new BufferBlock<CallContext>(); readonly BufferBlock<TestCallContext> _ccs = new BufferBlock<TestCallContext>();
public override Task<AnswerOrCounterquestion> Invoke(ulong interfaceId, ushort methodId, public override Task<AnswerOrCounterquestion> Invoke(ulong interfaceId, ushort methodId,
DeserializerState args, CancellationToken cancellationToken = default(CancellationToken)) DeserializerState args, CancellationToken cancellationToken = default(CancellationToken))
{ {
var cc = new CallContext(interfaceId, methodId, args, cancellationToken); var cc = new TestCallContext(interfaceId, methodId, args, cancellationToken);
Assert.IsTrue(_ccs.Post(cc)); Assert.IsTrue(_ccs.Post(cc));
return cc.Result.Task; return cc.Result.Task;
} }
public Task<CallContext> WhenCalled => _ccs.ReceiveAsync(); public Task<TestCallContext> WhenCalled => _ccs.ReceiveAsync();
} }
} }

View File

@ -4,9 +4,9 @@ using Capnp.Rpc;
namespace Capnp.Net.Runtime.Tests namespace Capnp.Net.Runtime.Tests
{ {
class CallContext class TestCallContext
{ {
public CallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct) public TestCallContext(ulong interfaceId, ushort methodId, DeserializerState args, CancellationToken ct)
{ {
InterfaceId = interfaceId; InterfaceId = interfaceId;
MethodId = methodId; MethodId = methodId;

View File

@ -47,7 +47,7 @@ namespace Capnp
/// <summary> /// <summary>
/// The capabilities imported from the capability table. Only valid in RPC context. /// The capabilities imported from the capability table. Only valid in RPC context.
/// </summary> /// </summary>
public IReadOnlyList<Rpc.ConsumedCapability> Caps { get; set; } public IList<Rpc.ConsumedCapability> Caps { get; set; }
/// <summary> /// <summary>
/// Current segment (essentially Segments[CurrentSegmentIndex] /// Current segment (essentially Segments[CurrentSegmentIndex]
/// </summary> /// </summary>

View File

@ -25,6 +25,12 @@ namespace Capnp
if (to == null) if (to == null)
throw new ArgumentNullException(nameof(to)); throw new ArgumentNullException(nameof(to));
if (from.Caps != null && to.Caps != null)
{
to.Caps.Clear();
to.Caps.AddRange(from.Caps);
}
var ds = to.Rewrap<DynamicSerializerState>(); var ds = to.Rewrap<DynamicSerializerState>();
IReadOnlyList<DeserializerState> items; IReadOnlyList<DeserializerState> items;

View File

@ -0,0 +1,238 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Capnp.Rpc.Interception
{
/// <summary>
/// Context of an intercepted call. Provides access to parameters and results,
/// and the possibility to redirect the call to some other capability.
/// </summary>
public class CallContext
{
class PromisedAnswer : IPromisedAnswer
{
CallContext _callContext;
TaskCompletionSource<DeserializerState> _futureResult = new TaskCompletionSource<DeserializerState>();
public PromisedAnswer(CallContext callContext)
{
_callContext = callContext;
}
public Task<DeserializerState> WhenReturned => _futureResult.Task;
async Task<Proxy> AccessWhenReturned(MemberAccessPath access)
{
await WhenReturned;
return new Proxy(Access(access));
}
public ConsumedCapability Access(MemberAccessPath access)
{
if (_futureResult.Task.IsCompleted)
{
try
{
return access.Eval(WhenReturned.Result);
}
catch (AggregateException exception)
{
throw exception.InnerException;
}
}
else
{
return new LazyCapability(AccessWhenReturned(access));
}
}
public void Dispose()
{
}
public void Return()
{
if (_callContext.IsCanceled)
{
_futureResult.SetCanceled();
}
else if (_callContext.Exception != null)
{
_futureResult.SetException(new RpcException(_callContext.Exception));
}
else
{
_futureResult.SetResult(_callContext.OutArgs);
}
}
}
public ulong InterfaceId { get; }
public ushort MethodId { get; }
public bool IsTailCall { get; }
public InterceptionState State { get; private set; }
public SerializerState InArgs { get; set; }
public DeserializerState OutArgs { get; set; }
public string Exception { get; set; }
public bool IsCanceled { get; set; }
public object Bob
{
get => _bob;
set
{
if (value != _bob)
{
BobProxy?.Dispose();
BobProxy = null;
_bob = value;
switch (value)
{
case Proxy proxy:
BobProxy = proxy;
break;
case Skeleton skeleton:
BobProxy = CapabilityReflection.CreateProxy<object>(
LocalCapability.Create(skeleton));
break;
case ConsumedCapability cap:
BobProxy = CapabilityReflection.CreateProxy<object>(cap);
break;
case null:
break;
default:
BobProxy = CapabilityReflection.CreateProxy<object>(
LocalCapability.Create(
Skeleton.GetOrCreateSkeleton(value, false)));
break;
}
}
}
}
internal Proxy BobProxy { get; private set; }
readonly CensorCapability _censorCapability;
PromisedAnswer _promisedAnswer;
object _bob;
internal IPromisedAnswer Answer => _promisedAnswer;
internal CallContext(CensorCapability censorCapability, ulong interfaceId, ushort methodId, SerializerState inArgs)
{
_censorCapability = censorCapability;
_promisedAnswer = new PromisedAnswer(this);
Bob = censorCapability.InterceptedCapability;
InterfaceId = interfaceId;
MethodId = methodId;
InArgs = inArgs;
State = InterceptionState.RequestedFromAlice;
}
static void InterceptCaps(DeserializerState state, IInterceptionPolicy policy)
{
if (state.Caps != null)
{
for (int i = 0; i < state.Caps.Count; i++)
{
state.Caps[i] = policy.Attach(state.Caps[i]);
state.Caps[i].AddRef();
}
}
}
static void UninterceptCaps(DeserializerState state, IInterceptionPolicy policy)
{
if (state.Caps != null)
{
for (int i = 0; i < state.Caps.Count; i++)
{
state.Caps[i] = policy.Detach(state.Caps[i]);
state.Caps[i].AddRef();
}
}
}
public void InterceptInCaps(IInterceptionPolicy policyOverride = null)
{
InterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy);
}
public void InterceptOutCaps(IInterceptionPolicy policyOverride = null)
{
InterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy);
}
public void UninterceptInCaps(IInterceptionPolicy policyOverride = null)
{
UninterceptCaps(InArgs, policyOverride ?? _censorCapability.Policy);
}
public void UninterceptOutCaps(IInterceptionPolicy policyOverride = null)
{
UninterceptCaps(OutArgs, policyOverride ?? _censorCapability.Policy);
}
public void ForwardToBob(CancellationToken cancellationToken = default)
{
if (Bob == null)
{
throw new InvalidOperationException("Bob is null");
}
var answer = BobProxy.Call(InterfaceId, MethodId, InArgs.Rewrap<DynamicSerializerState>(), IsTailCall, cancellationToken);
State = InterceptionState.ForwardedToBob;
async void ChangeStateWhenReturned()
{
using (answer)
{
try
{
OutArgs = await answer.WhenReturned;
}
catch (TaskCanceledException)
{
IsCanceled = true;
}
catch (System.Exception exception)
{
Exception = exception.Message;
}
}
State = InterceptionState.ReturnedFromBob;
_censorCapability.Policy.OnReturnFromBob(this);
}
ChangeStateWhenReturned();
}
public void ReturnToAlice()
{
try
{
_promisedAnswer.Return();
}
catch (InvalidOperationException)
{
throw new InvalidOperationException("The call was already returned");
}
State = InterceptionState.ReturnedToAlice;
}
}
}

View File

@ -0,0 +1,44 @@
namespace Capnp.Rpc.Interception
{
class CensorCapability : RefCountingCapability
{
public CensorCapability(ConsumedCapability interceptedCapability, IInterceptionPolicy policy)
{
InterceptedCapability = interceptedCapability;
interceptedCapability.AddRef();
Policy = policy;
MyVine = Vine.Create(this);
}
public ConsumedCapability InterceptedCapability { get; }
public IInterceptionPolicy Policy { get; }
internal Skeleton MyVine { get; }
protected override void ReleaseRemotely()
{
InterceptedCapability.Release();
}
internal override IPromisedAnswer DoCall(ulong interfaceId, ushort methodId, DynamicSerializerState args, bool tailCall)
{
var cc = new CallContext(this, interfaceId, methodId, args);
Policy.OnCallFromAlice(cc);
return cc.Answer;
}
internal override void Export(IRpcEndpoint endpoint, CapDescriptor.WRITER writer)
{
writer.which = CapDescriptor.WHICH.SenderHosted;
writer.SenderHosted = endpoint.AllocateExport(MyVine, out bool _);
}
internal override void Freeze(out IRpcEndpoint boundEndpoint)
{
boundEndpoint = null;
}
internal override void Unfreeze()
{
}
}
}

View File

@ -0,0 +1,10 @@
using System;
namespace Capnp.Rpc.Interception
{
public interface IInterceptionPolicy: IEquatable<IInterceptionPolicy>
{
void OnCallFromAlice(CallContext callContext);
void OnReturnFromBob(CallContext callContext);
}
}

View File

@ -0,0 +1,28 @@
namespace Capnp.Rpc.Interception
{
/// <summary>
/// The state of an intercepted call from Alice to Bob.
/// </summary>
public enum InterceptionState
{
/// <summary>
/// Alice initiated the call, but it was neither forwarded to Bob nor finished.
/// </summary>
RequestedFromAlice,
/// <summary>
/// The call was forwarded to Bob.
/// </summary>
ForwardedToBob,
/// <summary>
/// The call returned from Bob (to whom it was forwarded), but no result was yet forwarded to Alice.
/// </summary>
ReturnedFromBob,
/// <summary>
/// The call was returned to Alice (either with results, exception, or cancelled)
/// </summary>
ReturnedToAlice
}
}

View File

@ -0,0 +1,94 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
namespace Capnp.Rpc.Interception
{
public static class Interceptor
{
static readonly ConditionalWeakTable<ConsumedCapability, CensorCapability> _interceptMap =
new ConditionalWeakTable<ConsumedCapability, CensorCapability>();
public static TCap Attach<TCap>(this IInterceptionPolicy policy, TCap cap)
where TCap: class
{
if (policy == null)
throw new ArgumentNullException(nameof(policy));
if (cap == null)
throw new ArgumentNullException(nameof(cap));
var cur = cap as CensorCapability;
while (cur != null)
{
if (policy.Equals(cur.Policy))
{
return cap;
}
cur = cur.InterceptedCapability as CensorCapability;
}
switch (cap)
{
case Proxy proxy:
return CapabilityReflection.CreateProxy<TCap>(Attach(policy, proxy.ConsumedCap)) as TCap;
case ConsumedCapability ccap:
return new CensorCapability(ccap, policy) as TCap;
default:
return Attach(policy,
CapabilityReflection.CreateProxy<TCap>(
LocalCapability.Create(
Skeleton.GetOrCreateSkeleton(cap, false))) as TCap);
}
}
public static TCap Detach<TCap>(this IInterceptionPolicy policy, TCap cap)
where TCap: class
{
if (policy == null)
throw new ArgumentNullException(nameof(policy));
if (cap == null)
throw new ArgumentNullException(nameof(cap));
switch (cap)
{
case Proxy proxy:
return CapabilityReflection.CreateProxy<TCap>(Detach(policy, proxy.ConsumedCap)) as TCap;
case CensorCapability ccap:
{
var cur = ccap;
var stk = new Stack<IInterceptionPolicy>();
do
{
if (policy.Equals(cur.Policy))
{
var cur2 = cur.InterceptedCapability;
foreach (var p in stk)
{
cur2 = p.Attach(cur2);
}
return cur2 as TCap;
}
stk.Push(cur.Policy);
cur = cur.InterceptedCapability as CensorCapability;
}
while (cur != null);
return ccap as TCap;
}
default:
return cap;
}
}
}
}

View File

@ -38,40 +38,6 @@ namespace Capnp.Rpc
} }
} }
//public Task<SerializerState> WhenReady => ChainedAwaitWhenReady();
//public void Pipeline(PromisedAnswer.READER rd, Action<Proxy> action, Action<System.Exception> error)
//{
// lock (_reentrancyBlocker)
// {
// if (_chainedTask == null)
// {
// _chainedTask = InitialAwaitWhenReady();
// }
// _chainedTask = _chainedTask.ContinueWith(t =>
// {
// bool rethrow = true;
// try
// {
// t.Wait();
// rethrow = false;
// EvaluateProxyAndCallContinuation(rd, action);
// }
// catch (AggregateException aggregateException)
// {
// var innerException = aggregateException.InnerException;
// error(innerException);
// if (rethrow) throw innerException;
// }
// },
// TaskContinuationOptions.ExecuteSynchronously);
// }
//}
async Task AwaitChainedTask(Task chainedTask, Func<Task<AnswerOrCounterquestion>, Task> func) async Task AwaitChainedTask(Task chainedTask, Func<Task<AnswerOrCounterquestion>, Task> func)
{ {
try try

View File

@ -1311,7 +1311,7 @@ namespace Capnp.Rpc
} }
} }
public IReadOnlyList<ConsumedCapability> ImportCapTable(Payload.READER payload) public IList<ConsumedCapability> ImportCapTable(Payload.READER payload)
{ {
var list = new List<ConsumedCapability>(); var list = new List<ConsumedCapability>();

View File

@ -39,6 +39,9 @@ namespace Capnp.Rpc
internal static Skeleton GetOrCreateSkeleton<T>(T impl, bool addRef) internal static Skeleton GetOrCreateSkeleton<T>(T impl, bool addRef)
where T: class where T: class
{ {
if (impl == null)
throw new ArgumentNullException(nameof(impl));
if (impl is Skeleton skel) if (impl is Skeleton skel)
return skel; return skel;