diff --git a/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj b/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj index b72dab4..42ee749 100644 --- a/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj +++ b/Capnp.Net.Runtime.Tests.Core21/Capnp.Net.Runtime.Tests.Core21.csproj @@ -19,6 +19,7 @@ + @@ -28,6 +29,7 @@ + @@ -44,6 +46,7 @@ + diff --git a/Capnp.Net.Runtime.Tests/Capnp.Net.Runtime.Tests.Std20.csproj b/Capnp.Net.Runtime.Tests/Capnp.Net.Runtime.Tests.Std20.csproj index 1fdeec0..cb369cc 100644 --- a/Capnp.Net.Runtime.Tests/Capnp.Net.Runtime.Tests.Std20.csproj +++ b/Capnp.Net.Runtime.Tests/Capnp.Net.Runtime.Tests.Std20.csproj @@ -24,6 +24,7 @@ + diff --git a/Capnp.Net.Runtime.Tests/LocalRpc.cs b/Capnp.Net.Runtime.Tests/LocalRpc.cs new file mode 100644 index 0000000..232313b --- /dev/null +++ b/Capnp.Net.Runtime.Tests/LocalRpc.cs @@ -0,0 +1,31 @@ +using Capnp.Net.Runtime.Tests.GenImpls; +using Capnp.Rpc; +using Capnproto_test.Capnp.Test; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Capnp.Net.Runtime.Tests +{ + [TestClass] + public class LocalRpc + { + [TestMethod] + public void DeferredLocalAnswer() + { + var tcs = new TaskCompletionSource(); + var impl = new TestPipelineImpl2(tcs.Task); + var bproxy = BareProxy.FromImpl(impl); + var proxy = bproxy.Cast(true); + var cap = proxy.GetCap(0, null).OutBox_Cap(); + var foo = cap.Foo(123, true); + tcs.SetResult(0); + Assert.IsTrue(foo.Wait(TestBase.MediumNonDbgTimeout)); + Assert.AreEqual("bar", foo.Result); + } + } +} diff --git a/Capnp.Net.Runtime.Tests/TcpRpcErrorHandling.cs b/Capnp.Net.Runtime.Tests/TcpRpcErrorHandling.cs new file mode 100644 index 0000000..163d85e --- /dev/null +++ b/Capnp.Net.Runtime.Tests/TcpRpcErrorHandling.cs @@ -0,0 +1,608 @@ +using Capnp; +using Capnp.Net.Runtime.Tests.GenImpls; +using Capnp.Rpc; +using Capnproto_test.Capnp.Test; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Capnp.Net.Runtime.Tests +{ + [TestClass] + public class TcpRpcErrorHandling: TestBase + { + class MemStreamEndpoint : IEndpoint + { + readonly FramePump _fromEnginePump; + readonly BinaryReader _reader; + + public bool Dismissed { get; private set; } + + public MemStreamEndpoint() + { + var pipe = new Pipe(); + _fromEnginePump = new FramePump(pipe.Writer.AsStream()); + _reader = new BinaryReader(pipe.Reader.AsStream()); + } + + public void Dismiss() + { + Dismissed = true; + } + + public void Forward(WireFrame frame) + { + _fromEnginePump.Send(frame); + } + + public WireFrame ReadNextFrame() + { + return _reader.ReadWireFrame(); + } + } + + class RpcEngineTester + { + readonly MemStreamEndpoint _fromEngine; + + public RpcEngineTester() + { + Engine = new RpcEngine(); + _fromEngine = new MemStreamEndpoint(); + RealEnd = Engine.AddEndpoint(_fromEngine); + } + + public RpcEngine Engine { get; } + public RpcEngine.RpcEndpoint RealEnd { get; } + public bool IsDismissed => _fromEngine.Dismissed; + + public void Send(Action build) + { + var mb = MessageBuilder.Create(); + mb.InitCapTable(); + build(mb.BuildRoot()); + RealEnd.Forward(mb.Frame); + } + + public void Recv(Action verify) + { + var task = Task.Run(() => DeserializerState.CreateRoot(_fromEngine.ReadNextFrame())); + Assert.IsTrue(task.Wait(MediumNonDbgTimeout), "reception timeout"); + verify(new Message.READER(task.Result)); + } + + public void ExpectAbort() + { + Recv(_ => { Assert.AreEqual(Message.WHICH.Abort, _.which); }); + Assert.IsTrue(IsDismissed); + Assert.ThrowsException( + () => Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 33; })); + } + } + + [TestMethod] + public void DuplicateQuestion1() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; }); + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void DuplicateQuestion2() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 99; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestInterface).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + _.Call.Params.Content.Rewrap(); + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void DuplicateQuestion3() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 42; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestInterface).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + var wr = _.Call.Params.Content.Rewrap(); + wr.I = 123u; + wr.J = true; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 42; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestInterface).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + _.Call.Params.Content.Rewrap(); + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void NoBootstrap() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 0; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Exception, _.Return.which); + }); + Assert.IsFalse(tester.IsDismissed); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + tester.Send(_ => { _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 1; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + }); + } + + [TestMethod] + public void DuplicateFinish() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { + _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Finish; + _.Finish.QuestionId = 99; + }); + tester.Send(_ => { + _.which = Message.WHICH.Finish; + _.Finish.QuestionId = 99; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void DuplicateAnswer() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderHosted; + _.Return.Results.CapTable[0].SenderHosted = 1; + }); + Assert.IsTrue(proxy.WhenResolved.IsCompleted); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderHosted; + _.Return.Results.CapTable[0].SenderHosted = 1; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void InvalidReceiverHosted() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.ReceiverHosted; + _.Return.Results.CapTable[0].ReceiverHosted = 0; + }); + Assert.IsTrue(proxy.WhenResolved.IsCompleted); + tester.ExpectAbort(); + } + + [TestMethod] + public void InvalidReceiverAnswer() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.ReceiverAnswer; + _.Return.Results.CapTable[0].ReceiverAnswer.QuestionId = 0; + _.Return.Results.CapTable[0].ReceiverAnswer.Transform.Init(1); + _.Return.Results.CapTable[0].ReceiverAnswer.Transform[0].which = PromisedAnswer.Op.WHICH.GetPointerField; + _.Return.Results.CapTable[0].ReceiverAnswer.Transform[0].GetPointerField = 0; + }); + Assert.IsTrue(proxy.WhenResolved.IsCompleted); + tester.ExpectAbort(); + } + + [TestMethod] + public void DuplicateResolve() + { + var tester = new RpcEngineTester(); + + var cap = tester.RealEnd.QueryMain(); + var proxy = new BareProxy(cap); + Assert.IsFalse(proxy.WhenResolved.IsCompleted); + uint id = 0; + + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Bootstrap, _.which); + id = _.Bootstrap.QuestionId; + }); + tester.Send(_ => { + _.which = Message.WHICH.Return; + _.Return.which = Return.WHICH.Results; + _.Return.Results.CapTable.Init(1); + _.Return.Results.CapTable[0].which = CapDescriptor.WHICH.SenderPromise; + _.Return.Results.CapTable[0].SenderPromise = 0; + }); + tester.Send(_ => { + _.which = Message.WHICH.Resolve; + _.Resolve.which = Resolve.WHICH.Cap; + _.Resolve.Cap.which = CapDescriptor.WHICH.SenderHosted; + _.Resolve.Cap.SenderHosted = 1; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Finish, _.which); + }); + tester.Send(_ => { + _.which = Message.WHICH.Resolve; + _.Resolve.which = Resolve.WHICH.Exception; + _.Resolve.Exception.Reason = "problem"; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Release, _.which); + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void DuplicateRelease1() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { + _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Release; + _.Release.Id = bootCapId; + _.Release.ReferenceCount = 1; + }); + tester.Send(_ => { + _.which = Message.WHICH.Release; + _.Release.Id = bootCapId; + _.Release.ReferenceCount = 1; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void DuplicateRelease2() + { + var tester = new RpcEngineTester(); + tester.Engine.Main = new TestInterfaceImpl(new Counters()); + + uint bootCapId = 0; + + tester.Send(_ => { + _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Release; + _.Release.Id = bootCapId; + _.Release.ReferenceCount = 2; + }); + tester.ExpectAbort(); + } + + [TestMethod] + public void UnimplementedAccept() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { + _.which = Message.WHICH.Accept; + _.Accept.Embargo = true; + _.Accept.QuestionId = 47; + _.Accept.Provision.SetStruct(1, 0); + _.Accept.Provision.Allocate(); + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + Assert.AreEqual(Message.WHICH.Accept, _.Unimplemented.which); + Assert.IsTrue(_.Unimplemented.Accept.Embargo); + Assert.AreEqual(47u, _.Unimplemented.Accept.QuestionId); + Assert.AreEqual(1, _.Unimplemented.Accept.Provision.StructDataCount); + Assert.AreEqual(0, _.Unimplemented.Accept.Provision.StructPtrCount); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void UnimplementedJoin() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { + _.which = Message.WHICH.Join; + _.Join.QuestionId = 74; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + Assert.AreEqual(Message.WHICH.Join, _.Unimplemented.which); + Assert.AreEqual(74u, _.Unimplemented.Join.QuestionId); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void UnimplementedProvide() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { + _.which = Message.WHICH.Provide; + _.Provide.QuestionId = 666; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + Assert.AreEqual(Message.WHICH.Provide, _.Unimplemented.which); + Assert.AreEqual(666u, _.Unimplemented.Provide.QuestionId); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void UnimplementedObsoleteDelete() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { + _.which = Message.WHICH.ObsoleteDelete; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + Assert.AreEqual(Message.WHICH.ObsoleteDelete, _.Unimplemented.which); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void UnimplementedObsoleteSave() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { + _.which = Message.WHICH.ObsoleteSave; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + Assert.AreEqual(Message.WHICH.ObsoleteSave, _.Unimplemented.which); + }); + Assert.IsFalse(tester.IsDismissed); + } + + [TestMethod] + public void UnimplementedUnknown() + { + var tester = new RpcEngineTester(); + + tester.Send(_ => { + _.which = (Message.WHICH)123; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Unimplemented, _.which); + Assert.AreEqual((Message.WHICH)123, _.Unimplemented.which); + }); + Assert.IsFalse(tester.IsDismissed); + } + + class TestPipelineImpl3 : ITestPipeline + { + readonly TestPipelineImpl2 _impl; + readonly ITestPipeline _proxy; + + public TestPipelineImpl3(Task complete) + { + _impl = new TestPipelineImpl2(complete); + var bproxy = BareProxy.FromImpl(_impl); + _proxy = bproxy.Cast(true); + } + + public void Dispose() + { + } + + public bool IsGrandsonCapDisposed => _impl.IsChildCapDisposed; + + public Task<(string, TestPipeline.AnyBox)> GetAnyCap(uint n, BareProxy inCap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task<(string, TestPipeline.Box)> GetCap(uint n, ITestInterface inCap, CancellationToken cancellationToken_ = default) + { + return Task.FromResult(("foo", new TestPipeline.Box() { Cap = _proxy.GetCap(0, null).OutBox_Cap() })); + } + + public Task TestPointers(ITestInterface cap, object obj, IReadOnlyList list, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + } + + [TestMethod] + public void UnimplementedResolve() + { + var tcs = new TaskCompletionSource(); + var tester = new RpcEngineTester(); + var impl = new TestPipelineImpl3(tcs.Task); + tester.Engine.Main = impl; + + uint bootCapId = 0; + + tester.Send(_ => { + _.which = Message.WHICH.Bootstrap; _.Bootstrap.QuestionId = 99; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + bootCapId = _.Return.Results.CapTable[0].SenderHosted; + }); + tester.Send(_ => { + _.which = Message.WHICH.Call; + _.Call.QuestionId = 42; + _.Call.Target.which = MessageTarget.WHICH.ImportedCap; + _.Call.Target.ImportedCap = bootCapId; + _.Call.InterfaceId = ((TypeIdAttribute)typeof(ITestPipeline).GetCustomAttributes(typeof(TypeIdAttribute), false)[0]).Id; + _.Call.MethodId = 0; + var wr = _.Call.Params.Content.Rewrap(); + wr.InCap = null; + _.Call.Params.CapTable.Init(1); + _.Call.Params.CapTable[0].which = CapDescriptor.WHICH.ReceiverHosted; + }); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Return, _.which); + Assert.AreEqual(Return.WHICH.Results, _.Return.which); + Assert.AreEqual(1, _.Return.Results.CapTable.Count); + Assert.AreEqual(CapDescriptor.WHICH.SenderPromise, _.Return.Results.CapTable[0].which); + }); + tcs.SetResult(0); + tester.Recv(_ => { + Assert.AreEqual(Message.WHICH.Resolve, _.which); + Assert.AreEqual(Resolve.WHICH.Cap, _.Resolve.which); + Assert.AreEqual(CapDescriptor.WHICH.SenderHosted, _.Resolve.Cap.which); + + Assert.IsFalse(impl.IsGrandsonCapDisposed); + + tester.Send(_1 => + { + _1.which = Message.WHICH.Unimplemented; + _1.Unimplemented.which = Message.WHICH.Resolve; + Reserializing.DeepCopy(_, _1.Unimplemented.Resolve); + }); + + Assert.IsFalse(impl.IsGrandsonCapDisposed); + + tester.Send(_1 => + { + _1.which = Message.WHICH.Finish; + _1.Finish.QuestionId = 42; + _1.Finish.ReleaseResultCaps = true; + }); + + Assert.IsTrue(impl.IsGrandsonCapDisposed); + }); + Assert.IsFalse(tester.IsDismissed); + } + } +} diff --git a/Capnp.Net.Runtime.Tests/TestCapImplementations.cs b/Capnp.Net.Runtime.Tests/TestCapImplementations.cs index 029ac0f..44cfde2 100644 --- a/Capnp.Net.Runtime.Tests/TestCapImplementations.cs +++ b/Capnp.Net.Runtime.Tests/TestCapImplementations.cs @@ -432,6 +432,33 @@ namespace Capnp.Net.Runtime.Tests.GenImpls } } + class TestInterfaceImpl2 : ITestInterface + { + public Task Bar(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task Baz(TestAllTypes s, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public void Dispose() + { + IsDisposed = true; + } + + public bool IsDisposed { get; private set; } + + public Task Foo(uint i, bool j, CancellationToken cancellationToken_ = default) + { + Assert.AreEqual(123u, i); + Assert.IsTrue(j); + return Task.FromResult("bar"); + } + } + #endregion TestInterface #region TestExtends @@ -506,6 +533,41 @@ namespace Capnp.Net.Runtime.Tests.GenImpls throw new NotImplementedException(); } } + + class TestPipelineImpl2 : ITestPipeline + { + readonly Task _deblock; + readonly TestInterfaceImpl2 _timpl2; + + public TestPipelineImpl2(Task deblock) + { + _deblock = deblock; + _timpl2 = new TestInterfaceImpl2(); + } + + public void Dispose() + { + } + + public bool IsChildCapDisposed => _timpl2.IsDisposed; + + public Task<(string, TestPipeline.AnyBox)> GetAnyCap(uint n, BareProxy inCap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public async Task<(string, TestPipeline.Box)> GetCap(uint n, ITestInterface inCap, CancellationToken cancellationToken_ = default) + { + await _deblock; + return ("hello", new TestPipeline.Box() { Cap = _timpl2 }); + } + + public Task TestPointers(ITestInterface cap, object obj, IReadOnlyList list, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + } + #endregion TestPipeline #region TestCallOrder diff --git a/Capnp.Net.Runtime/Rpc/RpcEngine.cs b/Capnp.Net.Runtime/Rpc/RpcEngine.cs index 2378959..a11677d 100644 --- a/Capnp.Net.Runtime/Rpc/RpcEngine.cs +++ b/Capnp.Net.Runtime/Rpc/RpcEngine.cs @@ -60,8 +60,14 @@ namespace Capnp.Rpc } } - internal class RpcEndpoint : IEndpoint, IRpcEndpoint + public class RpcEndpoint : IEndpoint, IRpcEndpoint { + public enum EndpointState + { + Active, + Dismissed + } + static readonly ThreadLocal _exportCapTablePostActions = new ThreadLocal(); static readonly ThreadLocal _tailCall = new ThreadLocal(); static readonly ThreadLocal _canDeferCalls = new ThreadLocal(); @@ -87,8 +93,11 @@ namespace Capnp.Rpc { _host = host; _tx = tx; + State = EndpointState.Active; } + public EndpointState State { get; private set; } + public void Dismiss() { lock (_reentrancyBlocker) @@ -105,6 +114,8 @@ namespace Capnp.Rpc _answerTable.Clear(); _pendingDisembargos.Clear(); + + State = EndpointState.Dismissed; } _tx.Dismiss(); @@ -112,6 +123,9 @@ namespace Capnp.Rpc public void Forward(WireFrame frame) { + if (State == EndpointState.Dismissed) + throw new InvalidOperationException("Endpoint is in dismissed state and doesn't accept frames anymore"); + Interlocked.Increment(ref _recvCount); ProcessFrame(frame); } @@ -327,6 +341,7 @@ namespace Capnp.Rpc if (!added) { Logger.LogWarning("Incoming bootstrap request: Peer specified duplicate (not yet released?) answer ID."); + throw new RpcProtocolErrorException("Duplicate question ID"); } @@ -390,9 +405,7 @@ namespace Capnp.Rpc pendingAnswer.Cancel(); pendingAnswer.Dispose(); - SendAbort($"There is another pending answer for the same question ID {req.QuestionId}."); - - return; + throw new RpcProtocolErrorException($"There is another pending answer for the same question ID {req.QuestionId}."); } switch (req.SendResultsTo.which) @@ -553,8 +566,7 @@ namespace Capnp.Rpc { Logger.LogWarning("Incoming RPC call: Peer asked for invalid (already released?) capability ID."); - SendAbort($"Requested capability with ID {req.Target.ImportedCap} does not exist."); - return; + throw new RpcProtocolErrorException($"Requested capability with ID {req.Target.ImportedCap} does not exist."); } } @@ -606,8 +618,7 @@ namespace Capnp.Rpc else { Logger.LogWarning("Incoming RPC call: Peer asked for non-existing answer ID."); - SendAbort($"Did not find a promised answer for given ID {req.Target.PromisedAnswer.QuestionId}"); - return; + throw new RpcProtocolErrorException($"Did not find a promised answer for given ID {req.Target.PromisedAnswer.QuestionId}"); } } break; @@ -637,7 +648,7 @@ namespace Capnp.Rpc { Logger.LogWarning("Incoming RPC return: Unknown answer ID."); - return; + throw new RpcProtocolErrorException("Unknown answer ID"); } } @@ -722,7 +733,7 @@ namespace Capnp.Rpc if (!_importTable.TryGetValue(resolve.PromiseId, out var rcw)) { Logger.LogWarning("Received a resolve message with invalid ID"); - return; + throw new RpcProtocolErrorException($"Invalid ID {resolve.PromiseId}"); } if (!rcw.Cap.TryGetTarget(out var cap)) @@ -734,28 +745,33 @@ namespace Capnp.Rpc if (!(cap is PromisedCapability resolvableCap)) { Logger.LogWarning("Received a resolve message for a capability which is not a promise"); - return; + throw new RpcProtocolErrorException($"Not a promise {resolve.PromiseId}"); } - switch (resolve.which) + try { - case Resolve.WHICH.Cap: - lock (_reentrancyBlocker) - { - var resolvedCap = ImportCap(resolve.Cap); - if (resolvedCap == null) - resolvedCap = LazyCapability.CreateBrokenCap("Failed to resolve this capability"); - resolvableCap.ResolveTo(resolvedCap); - } - break; + switch (resolve.which) + { + case Resolve.WHICH.Cap: + lock (_reentrancyBlocker) + { + var resolvedCap = ImportCap(resolve.Cap); + resolvableCap.ResolveTo(resolvedCap); + } + break; - case Resolve.WHICH.Exception: - resolvableCap.Break(resolve.Exception.Reason); - break; + case Resolve.WHICH.Exception: + resolvableCap.Break(resolve.Exception.Reason); + break; - default: - Logger.LogWarning("Received a resolve message with unknown category."); - throw new RpcUnimplementedException(); + default: + Logger.LogWarning("Received a resolve message with unknown category."); + throw new RpcUnimplementedException(); + } + } + catch (InvalidOperationException) + { + throw new RpcProtocolErrorException($"Capability {resolve.PromiseId} was already resolved"); } } @@ -777,10 +793,7 @@ namespace Capnp.Rpc { Logger.LogWarning("Sender loopback request: Peer asked for invalid (already released?) capability ID."); - SendAbort("'Disembargo': Invalid capability ID"); - Dismiss(); - - return; + throw new RpcProtocolErrorException("'Disembargo': Invalid capability ID"); } reply.Target.which = MessageTarget.WHICH.ImportedCap; @@ -830,8 +843,7 @@ namespace Capnp.Rpc { Logger.LogWarning("Sender loopback request: Peer asked for disembargoing an answer which does not resolve back to the sender."); - SendAbort("'Disembargo': Answer does not resolve back to me"); - Dismiss(); + throw new RpcProtocolErrorException("'Disembargo': Answer does not resolve back to me"); } } finally @@ -844,9 +856,7 @@ namespace Capnp.Rpc { Logger.LogWarning($"Sender loopback request: Peer asked for disembargoing an answer which either has not yet returned, was canceled, or faulted: {exception.Message}"); - SendAbort($"'Disembargo' failure: {exception}"); - Dismiss(); - + throw new RpcProtocolErrorException($"'Disembargo' failure: {exception}"); } }); } @@ -854,10 +864,7 @@ namespace Capnp.Rpc { Logger.LogWarning("Sender loopback request: Peer asked for non-existing answer ID."); - SendAbort("'Disembargo': Invalid answer ID"); - Dismiss(); - - return; + throw new RpcProtocolErrorException("'Disembargo': Invalid answer ID"); } break; @@ -966,6 +973,8 @@ namespace Capnp.Rpc else { Logger.LogWarning("Peer sent 'finish' message with unknown question ID"); + + throw new RpcProtocolErrorException("unknown question ID"); } } @@ -994,6 +1003,8 @@ namespace Capnp.Rpc catch (System.Exception exception) { Logger.LogWarning($"Attempting to release capability with invalid reference count: {exception.Message}"); + + throw new RpcProtocolErrorException("Invalid reference count"); } } } @@ -1001,6 +1012,8 @@ namespace Capnp.Rpc if (!exists) { Logger.LogWarning("Attempting to release unknown capability ID"); + + throw new RpcProtocolErrorException("Invalid export ID"); } } @@ -1056,9 +1069,7 @@ namespace Capnp.Rpc //# In cases where there is no sensible way to react to an `unimplemented` message (without //# resource leaks or other serious problems), the connection may need to be aborted. This is //# a gray area; different implementations may take different approaches. - SendAbort("It's hopeless if you don't implement the bootstrap message"); - Dismiss(); - break; + throw new RpcProtocolErrorException("It's hopeless if you don't implement the bootstrap message"); default: // Looking at the various message types it feels OK to just not care about other 'unimplemented' @@ -1156,6 +1167,11 @@ namespace Capnp.Rpc Tx(mb.Frame); } + catch (RpcProtocolErrorException error) + { + SendAbort(error.Message); + Dismiss(); + } catch (System.Exception exception) { Logger.LogError(exception, "Uncaught exception during message processing."); @@ -1164,12 +1180,11 @@ namespace Capnp.Rpc // First, we send implementation specific details of a - maybe internal - error, not very valuable for the // receiver. But worse: From a security point of view, we might even reveil a secret here. SendAbort("Uncaught exception during RPC processing. This may also happen due to invalid input."); - Dismiss(); } } - ConsumedCapability? ImportCap(CapDescriptor.READER capDesc) + ConsumedCapability ImportCap(CapDescriptor.READER capDesc) { lock (_reentrancyBlocker) { @@ -1198,7 +1213,6 @@ namespace Capnp.Rpc rcw = new RefCounted>( new WeakReference(newCap)); _importTable.Add(capDesc.SenderHosted, rcw); - Debug.Assert(newCap != null); return newCap; } @@ -1217,7 +1231,6 @@ namespace Capnp.Rpc rcwp.Cap.SetTarget(impCap); } - Debug.Assert(impCap != null); return impCap; } else @@ -1226,7 +1239,6 @@ namespace Capnp.Rpc rcw = new RefCounted>( new WeakReference(newCap)); _importTable.Add(capDesc.SenderPromise, rcw); - Debug.Assert(newCap != null); return newCap; } @@ -1238,7 +1250,7 @@ namespace Capnp.Rpc else { Logger.LogWarning("Peer refers to receiver-hosted capability which does not exist."); - return null; + throw new RpcProtocolErrorException($"Receiver-hosted capability {capDesc.ReceiverHosted} does not exist."); } case CapDescriptor.WHICH.ReceiverAnswer: @@ -1270,7 +1282,7 @@ namespace Capnp.Rpc else { Logger.LogWarning("Peer refers to pending answer which does not exist."); - return null; + throw new RpcProtocolErrorException($"Invalid question ID {capDesc.ReceiverAnswer.QuestionId}"); } case CapDescriptor.WHICH.ThirdPartyHosted: @@ -1288,7 +1300,6 @@ namespace Capnp.Rpc rcv.Cap.SetTarget(impCap); } - Debug.Assert(impCap != null); return impCap; } else @@ -1296,7 +1307,6 @@ namespace Capnp.Rpc var newCap = new ImportedCapability(this, capDesc.ThirdPartyHosted.VineId); rcv = new RefCounted>( new WeakReference(newCap)); - Debug.Assert(newCap != null); return newCap; } @@ -1307,7 +1317,7 @@ namespace Capnp.Rpc } } - public IList ImportCapTable(Payload.READER payload) + internal IList ImportCapTable(Payload.READER payload) { var list = new List(); @@ -1481,7 +1491,7 @@ namespace Capnp.Rpc readonly ConcurrentBag _inboundEndpoints = new ConcurrentBag(); - internal RpcEndpoint AddEndpoint(IEndpoint outboundEndpoint) + public RpcEndpoint AddEndpoint(IEndpoint outboundEndpoint) { var inboundEndpoint = new RpcEndpoint(this, outboundEndpoint); _inboundEndpoints.Add(inboundEndpoint); @@ -1503,5 +1513,14 @@ namespace Capnp.Rpc _bootstrapCap = value; } } + + /// + /// Sets the bootstrap capability. It must be an object which implements a valid capability interface + /// (). + /// + public object Main + { + set { BootstrapCap = Skeleton.GetOrCreateSkeleton(value, false); } + } } } \ No newline at end of file diff --git a/Capnp.Net.Runtime/Rpc/RpcProtocolErrorException.cs b/Capnp.Net.Runtime/Rpc/RpcProtocolErrorException.cs new file mode 100644 index 0000000..6b4885d --- /dev/null +++ b/Capnp.Net.Runtime/Rpc/RpcProtocolErrorException.cs @@ -0,0 +1,9 @@ +namespace Capnp.Rpc +{ + class RpcProtocolErrorException : System.Exception + { + public RpcProtocolErrorException(string reason): base(reason) + { + } + } +} \ No newline at end of file diff --git a/Capnp.Net.Runtime/Rpc/TcpRpcServer.cs b/Capnp.Net.Runtime/Rpc/TcpRpcServer.cs index dcdb9fb..75ad222 100644 --- a/Capnp.Net.Runtime/Rpc/TcpRpcServer.cs +++ b/Capnp.Net.Runtime/Rpc/TcpRpcServer.cs @@ -325,7 +325,7 @@ namespace Capnp.Rpc /// public object Main { - set { _rpcEngine.BootstrapCap = Skeleton.GetOrCreateSkeleton(value, false); } + set { _rpcEngine.Main = value; } } ///