diff --git a/Capnp.Net.Runtime.Tests/Dtbdct.cs b/Capnp.Net.Runtime.Tests/Dtbdct.cs index fa48434..57ae39c 100644 --- a/Capnp.Net.Runtime.Tests/Dtbdct.cs +++ b/Capnp.Net.Runtime.Tests/Dtbdct.cs @@ -158,5 +158,23 @@ namespace Capnp.Net.Runtime.Tests { NewDtbdctTestbed().RunTest(Testsuite.ImportReceiverCanceled); } + + [TestMethod] + public void ButNoTailCall() + { + NewDtbdctTestbed().RunTest(Testsuite.ButNoTailCall); + } + + [TestMethod] + public void SecondIsTailCall() + { + NewDtbdctTestbed().RunTest(Testsuite.SecondIsTailCall); + } + + [TestMethod] + public void ReexportSenderPromise() + { + NewDtbdctTestbed().RunTest(Testsuite.ReexportSenderPromise); + } } } diff --git a/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs b/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs index e26b84d..795d451 100644 --- a/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs +++ b/Capnp.Net.Runtime.Tests/Mock/TestCapImplementations.cs @@ -681,6 +681,72 @@ namespace Capnp.Net.Runtime.Tests.GenImpls } } } + + class TestTailCallerImpl3 : ITestTailCaller + { + public TestTailCallerImpl3() + { + } + + public void Dispose() + { + } + + public Task Foo(int i, ITestTailCallee callee, CancellationToken cancellationToken_) + { + using (callee) + { + var task1 = callee.Foo(i, "from TestTailCaller 1", cancellationToken_); + + async void FinishTask() + { + var r = await task1; + r.C.Dispose(); + } + + FinishTask(); + + var task2 = callee.Foo(i, "from TestTailCaller 2", cancellationToken_); + + async void AssertIsTailCall() + { + try + { + await task2; + Assert.Fail("Not a tail call"); + } + catch (TailCallNoDataException) + { + } + } + + AssertIsTailCall(); + + return task2; + } + } + } + + class TestTailCallerImpl4 : ITestTailCaller + { + public TestTailCallerImpl4() + { + } + + public void Dispose() + { + } + + public async Task Foo(int i, ITestTailCallee callee, CancellationToken cancellationToken_) + { + await Task.Yield(); + + using (callee) + { + return await callee.Foo(i, "from TestTailCaller", cancellationToken_); + } + } + } #endregion TestTailCaller #region TestTailCallee @@ -993,7 +1059,7 @@ namespace Capnp.Net.Runtime.Tests.GenImpls { if (_echoCounter++ < 20) { - return Task.FromResult(((Proxy)cap).Cast(true).Echo(cap).Eager()); + return Task.FromResult(((Proxy)cap).Cast(false).Echo(cap).Eager()); } else { @@ -1056,6 +1122,192 @@ namespace Capnp.Net.Runtime.Tests.GenImpls } } + class TestMoreStuffImpl4 : ITestMoreStuff, ITestCallOrder + { + readonly TaskCompletionSource _heldCap = new TaskCompletionSource(); + + public Task CallFoo(ITestInterface cap, CancellationToken cancellationToken_ = default) + { + using (cap) + { + return cap.Foo(123, true); + } + } + + public Task CallFooWhenResolved(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task CallHeld(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public async void Dispose() + { + using (var cap = await _heldCap.Task) + { + } + } + + public Task Echo(ITestCallOrder cap, CancellationToken cancellationToken_ = default) + { + using (var target = ((Proxy)cap).Cast(false)) + { + return Task.FromResult(target.Echo(cap).Eager()); + } + } + + public Task ExpectCancel(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + uint _counter; + + public Task GetCallSequence(uint expected, CancellationToken cancellationToken_ = default) + { + Assert.AreEqual(_counter, expected); + return Task.FromResult(_counter++); + } + + public Task GetEnormousString(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task GetHandle(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task GetHeld(CancellationToken cancellationToken_ = default) + { + return Task.FromResult(_heldCap.Task.Eager(true)); + } + + public Task GetNull(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task Hold(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + _heldCap.SetResult(Cap); + return Task.CompletedTask; + } + + public Task<(string, string)> MethodWithDefaults(string A, uint B, string C, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task MethodWithNullDefault(string A, ITestInterface B, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task NeverReturn(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + } + + class TestMoreStuffImpl5 : ITestMoreStuff, ITestCallOrder + { + readonly TaskCompletionSource _heldCap = new TaskCompletionSource(); + + public Task CallFoo(ITestInterface cap, CancellationToken cancellationToken_ = default) + { + using (cap) + { + return cap.Foo(123, true); + } + } + + public Task CallFooWhenResolved(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task CallHeld(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public async void Dispose() + { + using (var cap = await _heldCap.Task) + { + } + } + + TaskCompletionSource _echoEnabled = new TaskCompletionSource(); + + public void EnableEcho() => _echoEnabled.SetResult(0); + + public async Task Echo(ITestCallOrder cap, CancellationToken cancellationToken_ = default) + { + await _echoEnabled.Task; + return cap; + } + + public Task ExpectCancel(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + uint _counter; + + public Task GetCallSequence(uint expected, CancellationToken cancellationToken_ = default) + { + Assert.AreEqual(_counter, expected); + return Task.FromResult(_counter++); + } + + public Task GetEnormousString(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task GetHandle(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task GetHeld(CancellationToken cancellationToken_ = default) + { + return Task.FromResult(_heldCap.Task.Eager(true)); + } + + public Task GetNull(CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task Hold(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + _heldCap.SetResult(Cap); + return Task.CompletedTask; + } + + public Task<(string, string)> MethodWithDefaults(string A, uint B, string C, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task MethodWithNullDefault(string A, ITestInterface B, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + + public Task NeverReturn(ITestInterface Cap, CancellationToken cancellationToken_ = default) + { + throw new NotImplementedException(); + } + } + #endregion TestMoreStuff #region TestHandle diff --git a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs index f9483bf..6d41fc9 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcAdvancedStuff.cs @@ -261,20 +261,29 @@ namespace Capnp.Net.Runtime.Tests { var echoTask = main.Echo(Proxy.Share(main)); Assert.IsTrue(echoTask.Wait(MediumNonDbgTimeout)); - var echo = echoTask.Result; - var list = new Task[1000]; - for (uint i = 0; i < list.Length; i++) + using (var echo = echoTask.Result) { - list[i] = echo.GetCallSequence(i); - } - Assert.IsTrue(Task.WaitAll(list, MediumNonDbgTimeout)); - for (uint i = 0; i < list.Length; i++) - { - Assert.AreEqual(i, list[i].Result); + var list = new Task[1000]; + for (uint i = 0; i < list.Length; i++) + { + list[i] = echo.GetCallSequence(i); + } + Assert.IsTrue(Task.WaitAll(list, MediumNonDbgTimeout)); + for (uint i = 0; i < list.Length; i++) + { + Assert.AreEqual(i, list[i].Result); + } } } } } } + + + [TestMethod] + public void NoTailCallMt() + { + NewLocalhostTcpTestbed().RunTest(Testsuite.NoTailCallMt); + } } } diff --git a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs index e093a27..5886456 100644 --- a/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs +++ b/Capnp.Net.Runtime.Tests/TcpRpcInterop.cs @@ -415,20 +415,22 @@ namespace Capnp.Net.Runtime.Tests var callee = new TestTailCalleeImpl(calleeCallCount); var promise = main.Foo(456, callee, default); - var dependentCall0 = promise.C().GetCallSequence(0, default); + using (var c = promise.C()) + { + var dependentCall0 = c.GetCallSequence(0, default); - Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); - Assert.AreEqual(456u, promise.Result.I); - Assert.AreEqual("from TestTailCaller", promise.Result.T); + Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); + Assert.AreEqual(456u, promise.Result.I); + Assert.AreEqual("from TestTailCaller", promise.Result.T); - var dependentCall1 = promise.C().GetCallSequence(0, default); - var dependentCall2 = promise.C().GetCallSequence(0, default); - - AssertOutput(stdout, "foo"); - Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(dependentCall1.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(dependentCall2.Wait(MediumNonDbgTimeout)); + var dependentCall1 = c.GetCallSequence(0, default); + var dependentCall2 = c.GetCallSequence(0, default); + AssertOutput(stdout, "foo"); + Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(dependentCall1.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(dependentCall2.Wait(MediumNonDbgTimeout)); + } Assert.AreEqual(1, calleeCallCount.CallCount); } } @@ -523,28 +525,30 @@ namespace Capnp.Net.Runtime.Tests using (var main = client.GetMain()) { var tcs = new TaskCompletionSource(); - var eager = tcs.Task.Eager(true); + using (var eager = tcs.Task.Eager(true)) + { + var request = main.CallFoo(Proxy.Share(eager), default); + AssertOutput(stdout, "callFoo"); + var request2 = main.CallFooWhenResolved(Proxy.Share(eager), default); + AssertOutput(stdout, "callFooWhenResolved"); - var request = main.CallFoo(eager, default); - AssertOutput(stdout, "callFoo"); - var request2 = main.CallFooWhenResolved(eager, default); - AssertOutput(stdout, "callFooWhenResolved"); + var gcs = main.GetCallSequence(0, default); + AssertOutput(stdout, "getCallSequence"); + Assert.IsTrue(gcs.Wait(MediumNonDbgTimeout)); + Assert.AreEqual(2u, gcs.Result); - var gcs = main.GetCallSequence(0, default); - AssertOutput(stdout, "getCallSequence"); - Assert.IsTrue(gcs.Wait(MediumNonDbgTimeout)); - Assert.AreEqual(2u, gcs.Result); + var chainedCallCount = new Counters(); + var tiimpl = new TestInterfaceImpl(chainedCallCount); + tcs.SetResult(tiimpl); - var chainedCallCount = new Counters(); - var tiimpl = new TestInterfaceImpl(chainedCallCount); - tcs.SetResult(tiimpl); + Assert.IsTrue(request.Wait(MediumNonDbgTimeout)); + Assert.IsTrue(request2.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(request.Wait(MediumNonDbgTimeout)); - Assert.IsTrue(request2.Wait(MediumNonDbgTimeout)); + Assert.AreEqual("bar", request.Result); + Assert.AreEqual("bar", request2.Result); + Assert.AreEqual(2, chainedCallCount.CallCount); - Assert.AreEqual("bar", request.Result); - Assert.AreEqual("bar", request2.Result); - Assert.AreEqual(2, chainedCallCount.CallCount); + } AssertOutput(stdout, "fin"); AssertOutput(stdout, "fin"); diff --git a/Capnp.Net.Runtime.Tests/Testsuite.cs b/Capnp.Net.Runtime.Tests/Testsuite.cs index 443beea..c64c0be 100644 --- a/Capnp.Net.Runtime.Tests/Testsuite.cs +++ b/Capnp.Net.Runtime.Tests/Testsuite.cs @@ -749,5 +749,74 @@ namespace Capnp.Net.Runtime.Tests Assert.IsTrue(foo.IsCanceled); } } + + public static void ButNoTailCall(ITestbed testbed) + { + var impl = new TestMoreStuffImpl4(); + using (var main = testbed.ConnectMain(impl)) + { + var peer = new TestMoreStuffImpl5(); + var heldTask = main.Echo(peer); + + testbed.MustComplete(heldTask); + + var r = heldTask.Result as IResolvingCapability; + + peer.EnableEcho(); + + testbed.MustComplete(r.WhenResolved); + + heldTask.Result.Dispose(); + } + } + + public static void SecondIsTailCall(ITestbed testbed) + { + var impl = new TestTailCallerImpl3(); + using (var main = testbed.ConnectMain(impl)) + { + var callee = new TestTailCalleeImpl(new Counters()); + var task = main.Foo(123, callee); + testbed.MustComplete(task); + Assert.AreEqual("from TestTailCaller 2", task.Result.T); + } + } + + public static void NoTailCallMt(ITestbed testbed) + { + var impl = new TestTailCallerImpl4(); + using (var main = testbed.ConnectMain(impl)) + using (var callee = Proxy.Share(new TestTailCalleeImpl(new Counters()))) + { + var tasks = ParallelEnumerable + .Range(0, 1000) + .Select(async i => + { + var r = await main.Foo(i, Proxy.Share(callee)); + Assert.AreEqual((uint)i, r.I); + }) + .ToArray(); + + testbed.MustComplete(tasks); + Assert.IsFalse(tasks.Any(t => t.IsCanceled || t.IsFaulted)); + } + } + + public static void ReexportSenderPromise(ITestbed testbed) + { + var impl = new TestTailCallerImpl(new Counters()); + using (var main = testbed.ConnectMain(impl)) + { + var tcs = new TaskCompletionSource(); + using (var promise = Proxy.Share(tcs.Task.Eager(true))) + { + var task1 = main.Foo(1, Proxy.Share(promise)); + var task2 = main.Foo(2, Proxy.Share(promise)); + var callee = new TestTailCalleeImpl(new Counters()); + tcs.SetResult(callee); + testbed.MustComplete(task1, task2); + } + } + } } } diff --git a/Capnp.Net.Runtime.Tests/Util/TestBase.cs b/Capnp.Net.Runtime.Tests/Util/TestBase.cs index 000b810..e40b7a1 100644 --- a/Capnp.Net.Runtime.Tests/Util/TestBase.cs +++ b/Capnp.Net.Runtime.Tests/Util/TestBase.cs @@ -242,11 +242,16 @@ namespace Capnp.Net.Runtime.Tests { (_server, _client) = SetupClientServerPair(); _client.WhenConnected.Wait(MediumNonDbgTimeout); + Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, MediumNonDbgTimeout)); + var conn = _server.Connections[0]; using (_server) using (_client) { action(this); + + Assert.IsTrue(SpinWait.SpinUntil(() => _client.SendCount == conn.RecvCount, MediumNonDbgTimeout)); + Assert.IsTrue(SpinWait.SpinUntil(() => conn.SendCount == _client.RecvCount, MediumNonDbgTimeout)); } } diff --git a/Capnp.Net.Runtime/Capnp.Net.Runtime.csproj b/Capnp.Net.Runtime/Capnp.Net.Runtime.csproj index b7675b9..6f81750 100644 --- a/Capnp.Net.Runtime/Capnp.Net.Runtime.csproj +++ b/Capnp.Net.Runtime/Capnp.Net.Runtime.csproj @@ -29,7 +29,7 @@ - + DebugFinalizers diff --git a/Capnp.Net.Runtime/Rpc/Impatient.cs b/Capnp.Net.Runtime/Rpc/Impatient.cs index d969ddc..0f4685c 100644 --- a/Capnp.Net.Runtime/Rpc/Impatient.cs +++ b/Capnp.Net.Runtime/Rpc/Impatient.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -11,7 +12,7 @@ namespace Capnp.Rpc public static class Impatient { static readonly ConditionalWeakTable _taskTable = new ConditionalWeakTable(); - static readonly ThreadLocal _askingEndpoint = new ThreadLocal(); + static readonly ThreadLocal> _askingEndpoint = new ThreadLocal>(() => new Stack()); /// /// Attaches a continuation to the given promise and registers the resulting task for pipelining. @@ -171,8 +172,17 @@ namespace Capnp.Rpc internal static IRpcEndpoint? AskingEndpoint { - get => _askingEndpoint.Value; - set { _askingEndpoint.Value = value; } + get => _askingEndpoint.Value!.Count > 0 ? _askingEndpoint.Value.Peek() : null; + } + + internal static void PushAskingEndpoint(IRpcEndpoint endpoint) + { + _askingEndpoint.Value!.Push(endpoint); + } + + internal static void PopAskingEndpoint() + { + _askingEndpoint.Value!.Pop(); } /// diff --git a/Capnp.Net.Runtime/Rpc/PendingQuestion.cs b/Capnp.Net.Runtime/Rpc/PendingQuestion.cs index 5884656..5d6aa50 100644 --- a/Capnp.Net.Runtime/Rpc/PendingQuestion.cs +++ b/Capnp.Net.Runtime/Rpc/PendingQuestion.cs @@ -66,14 +66,6 @@ namespace Capnp.Rpc _inParams = inParams; StateFlags = inParams == null ? State.Sent : State.None; - if (inParams != null) - { - foreach (var cap in inParams.Caps!) - { - cap.AddRef(); - } - } - if (target != null) { target.AddRef(); @@ -282,22 +274,6 @@ namespace Capnp.Rpc return new RemoteAnswerCapability(this, access, proxyTask); } - static void ReleaseCaps(ConsumedCapability? target, SerializerState? inParams) - { - if (inParams != null) - { - foreach (var cap in inParams.Caps!) - { - cap.Release(); - } - } - - if (target != null) - { - target.Release(); - } - } - static void ReleaseOutCaps(DeserializerState outParams) { foreach (var cap in outParams.Caps!) @@ -327,8 +303,8 @@ namespace Capnp.Rpc Debug.Assert(msg.Call!.Target.which != MessageTarget.WHICH.undefined); var call = msg.Call; call.QuestionId = QuestionId; - call.SendResultsTo.which = IsTailCall ? - Call.sendResultsTo.WHICH.Yourself : + call.SendResultsTo.which = IsTailCall ? + Call.sendResultsTo.WHICH.Yourself : Call.sendResultsTo.WHICH.Caller; try @@ -341,7 +317,7 @@ namespace Capnp.Rpc OnException(exception); } - ReleaseCaps(target, inParams); + target?.Release(); } /// diff --git a/Capnp.Net.Runtime/Rpc/Proxy.cs b/Capnp.Net.Runtime/Rpc/Proxy.cs index 419826e..bcfa0be 100644 --- a/Capnp.Net.Runtime/Rpc/Proxy.cs +++ b/Capnp.Net.Runtime/Rpc/Proxy.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using System; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; @@ -25,10 +26,6 @@ namespace Capnp.Rpc return BareProxy.FromImpl(obj).Cast(true); } -#if DebugFinalizers - ILogger Logger { get; } = Logging.CreateLogger(); -#endif - bool _disposedValue = false; /// @@ -56,15 +53,23 @@ namespace Capnp.Rpc return CapabilityReflection.CreateProxy(ConsumedCap) as T; } + ConsumedCapability? _consumedCap; + /// /// Underlying low-level capability /// - protected internal ConsumedCapability? ConsumedCap { get; private set; } + protected internal ConsumedCapability? ConsumedCap => _disposedValue ? + throw new ObjectDisposedException(nameof(Proxy)) : _consumedCap; /// /// Whether is this a broken capability. /// - public bool IsNull => ConsumedCap == null; + public bool IsNull => _consumedCap == null; + + /// + /// Whether was called on this Proxy. + /// + public bool IsDisposed => _disposedValue; static async void DisposeCtrWhenReturned(CancellationTokenRegistration ctr, IPromisedAnswer answer) { @@ -134,12 +139,12 @@ namespace Capnp.Rpc if (cap == null) return; - ConsumedCap = cap; + _consumedCap = cap; cap.AddRef(); #if DebugFinalizers - if (ConsumedCap != null) - ConsumedCap.OwningProxy = this; + if (_consumedCap != null) + _consumedCap.OwningProxy = this; #endif } @@ -166,14 +171,14 @@ namespace Capnp.Rpc { if (disposing) { - ConsumedCap?.Release(); + _consumedCap?.Release(); } else { // When called from the Finalizer, we must not throw. // But when reference counting goes wrong, ConsumedCapability.Release() will throw an InvalidOperationException. // The only option here is to suppress that exception. - try { ConsumedCap?.Release(); } + try { _consumedCap?.Release(); } catch { } } @@ -187,7 +192,7 @@ namespace Capnp.Rpc ~Proxy() { #if DebugFinalizers - Logger?.LogWarning($"Caught orphaned Proxy, created from here: {CreatorStackTrace}."); + Debugger.Log(0, "DebugFinalizers", $"Caught orphaned Proxy, created from here: {CreatorStackTrace}."); #endif Dispose(false); diff --git a/Capnp.Net.Runtime/Rpc/RpcEngine.cs b/Capnp.Net.Runtime/Rpc/RpcEngine.cs index 5a55cd6..c271d53 100644 --- a/Capnp.Net.Runtime/Rpc/RpcEngine.cs +++ b/Capnp.Net.Runtime/Rpc/RpcEngine.cs @@ -66,7 +66,7 @@ namespace Capnp.Rpc Dismissed } - static readonly ThreadLocal _tailCall = new ThreadLocal(); + static readonly ThreadLocal _deferredCall = new ThreadLocal(); static readonly ThreadLocal _canDeferCalls = new ThreadLocal(); ILogger Logger { get; } = Logging.CreateLogger(); @@ -81,6 +81,7 @@ namespace Capnp.Rpc readonly Dictionary _answerTable = new Dictionary(); readonly Dictionary> _pendingDisembargos = new Dictionary>(); readonly object _reentrancyBlocker = new object(); + readonly object _callReturnBlocker = new object(); long _recvCount; long _sendCount; @@ -284,16 +285,8 @@ namespace Capnp.Rpc if (_revExportTable.TryGetValue(providedCapability, out uint id)) { + _exportTable[id].AddRef(); first = false; - - if (_exportTable.TryGetValue(id, out var rc)) - { - rc.AddRef(); - } - else - { - Logger.LogError("Inconsistent export table: Capability with id {0} exists in reverse table only", id); - } } else { @@ -305,7 +298,6 @@ namespace Capnp.Rpc _revExportTable.Add(providedCapability, id); _exportTable.Add(id, new RefCounted(providedCapability)); - first = true; } @@ -407,8 +399,22 @@ namespace Capnp.Rpc } } + void DispatchDeferredCalls() + { + var call = _deferredCall.Value; + _deferredCall.Value = null; + call?.Send(); + } void ProcessCall(Call.READER req) + { + lock (_callReturnBlocker) + { + ProcessCallLocked(req); + } + } + + void ProcessCallLocked(Call.READER req) { Return.WRITER SetupReturn(MessageBuilder mb) { @@ -420,8 +426,10 @@ namespace Capnp.Rpc return ret; } - void ReturnCall(Action why) + void ReturnCallNoCapTable(Action why) { + DispatchDeferredCalls(); + var mb = MessageBuilder.Create(); mb.InitCapTable(); var ret = SetupReturn(mb); @@ -430,7 +438,10 @@ namespace Capnp.Rpc try { - Tx(mb.Frame); + lock (_callReturnBlocker) + { + Tx(mb.Frame); + } } catch (RpcException exception) { @@ -473,7 +484,7 @@ namespace Capnp.Rpc { Debug.Fail("Either answer or counter question must be present"); } - else if (aorcq.Answer != null || aorcq.Counterquestion != _tailCall.Value) + else if (aorcq.Answer != null || aorcq.Counterquestion != _deferredCall.Value) { var results = aorcq.Answer ?? (DynamicSerializerState)(await aorcq.Counterquestion!.WhenReturned); var ret = SetupReturn(results.MsgBuilder!); @@ -484,12 +495,13 @@ namespace Capnp.Rpc ret.which = Return.WHICH.Results; ret.Results!.Content = results.Rewrap(); ret.ReleaseParamCaps = releaseParamCaps; + DispatchDeferredCalls(); ExportCapTableAndSend(results, ret.Results); pendingAnswer.CapTable = ret.Results.CapTable; break; case Call.sendResultsTo.WHICH.Yourself: - ReturnCall(ret2 => + ReturnCallNoCapTable(ret2 => { ret2.which = Return.WHICH.ResultsSentElsewhere; ret2.ReleaseParamCaps = releaseParamCaps; @@ -499,11 +511,11 @@ namespace Capnp.Rpc } else if (aorcq.Counterquestion != null) { - _tailCall.Value = null; + _deferredCall.Value = null; aorcq.Counterquestion.IsTailCall = true; aorcq.Counterquestion.Send(); - ReturnCall(ret2 => + ReturnCallNoCapTable(ret2 => { ret2.which = Return.WHICH.TakeFromOtherQuestion; ret2.TakeFromOtherQuestion = aorcq.Counterquestion.QuestionId; @@ -513,7 +525,7 @@ namespace Capnp.Rpc } catch (TaskCanceledException) { - ReturnCall(ret => + ReturnCallNoCapTable(ret => { ret.which = Return.WHICH.Canceled; ret.ReleaseParamCaps = releaseParamCaps; @@ -521,7 +533,7 @@ namespace Capnp.Rpc } catch (System.Exception exception) { - ReturnCall(ret => + ReturnCallNoCapTable(ret => { ret.which = Return.WHICH.Exception; ret.Exception!.Reason = exception.Message; @@ -543,7 +555,7 @@ namespace Capnp.Rpc } finally { - ReturnCall(ret => + ReturnCallNoCapTable(ret => { ret.which = Return.WHICH.ResultsSentElsewhere; ret.ReleaseParamCaps = releaseParamCaps; @@ -600,7 +612,7 @@ namespace Capnp.Rpc } _canDeferCalls.Value = true; - Impatient.AskingEndpoint = this; + Impatient.PushAskingEndpoint(this); try { @@ -683,10 +695,8 @@ namespace Capnp.Rpc finally { _canDeferCalls.Value = false; - Impatient.AskingEndpoint = null; - var call = _tailCall.Value; - _tailCall.Value = null; - call?.Send(); + Impatient.PopAskingEndpoint(); + DispatchDeferredCalls(); } } @@ -1419,8 +1429,8 @@ namespace Capnp.Rpc if (_canDeferCalls.Value) { - _tailCall.Value?.Send(); - _tailCall.Value = question; + DispatchDeferredCalls(); + _deferredCall.Value = question; } else { diff --git a/Capnp.Net.Runtime/Rpc/Vine.cs b/Capnp.Net.Runtime/Rpc/Vine.cs index 3572f6c..1c4f29c 100644 --- a/Capnp.Net.Runtime/Rpc/Vine.cs +++ b/Capnp.Net.Runtime/Rpc/Vine.cs @@ -76,7 +76,12 @@ namespace Capnp.Rpc protected override void Dispose(bool disposing) { - Proxy.Dispose(); + if (disposing) + Proxy.Dispose(); + else + try { Proxy.Dispose(); } + catch { } + base.Dispose(disposing); } }